Browse Source

Allow OAuth2AuthorizationRequest to be extended

Closes gh-18049
pull/18050/head
Joe Grandja 4 months ago
parent
commit
fbf7bb3be1
  1. 182
      oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java
  2. 36
      oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java
  3. 68
      oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java

182
oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java

@ -16,6 +16,7 @@ @@ -16,6 +16,7 @@
package org.springframework.security.oauth2.core.endpoint;
import java.io.Serial;
import java.io.Serializable;
import java.net.URI;
import java.nio.charset.StandardCharsets;
@ -51,31 +52,46 @@ import org.springframework.web.util.UriUtils; @@ -51,31 +52,46 @@ import org.springframework.web.util.UriUtils;
* "https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Code
* Grant Request</a>
*/
public final class OAuth2AuthorizationRequest implements Serializable {
public class OAuth2AuthorizationRequest implements Serializable {
@Serial
private static final long serialVersionUID = 620L;
private String authorizationUri;
private final String authorizationUri;
private AuthorizationGrantType authorizationGrantType;
private final AuthorizationGrantType authorizationGrantType;
private OAuth2AuthorizationResponseType responseType;
private final OAuth2AuthorizationResponseType responseType;
private String clientId;
private final String clientId;
private String redirectUri;
private final String redirectUri;
private Set<String> scopes;
private final Set<String> scopes;
private String state;
private final String state;
private Map<String, Object> additionalParameters;
private final Map<String, Object> additionalParameters;
private String authorizationRequestUri;
private final String authorizationRequestUri;
private Map<String, Object> attributes;
private final Map<String, Object> attributes;
private OAuth2AuthorizationRequest() {
protected OAuth2AuthorizationRequest(AbstractBuilder<?, ?> builder) {
Assert.hasText(builder.authorizationUri, "authorizationUri cannot be empty");
Assert.hasText(builder.clientId, "clientId cannot be empty");
this.authorizationUri = builder.authorizationUri;
this.authorizationGrantType = builder.authorizationGrantType;
this.responseType = builder.responseType;
this.clientId = builder.clientId;
this.redirectUri = builder.redirectUri;
this.scopes = Collections.unmodifiableSet(
CollectionUtils.isEmpty(builder.scopes) ? Collections.emptySet() : new LinkedHashSet<>(builder.scopes));
this.state = builder.state;
this.additionalParameters = Collections.unmodifiableMap(builder.additionalParameters);
this.authorizationRequestUri = StringUtils.hasText(builder.authorizationRequestUri)
? builder.authorizationRequestUri : builder.buildAuthorizationRequestUri();
this.attributes = Collections.unmodifiableMap(builder.attributes);
}
/**
@ -185,7 +201,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -185,7 +201,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* @return the {@link Builder}
*/
public static Builder authorizationCode() {
return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE);
return new Builder();
}
@Override
@ -226,7 +242,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -226,7 +242,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
public static Builder from(OAuth2AuthorizationRequest authorizationRequest) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
// @formatter:off
return new Builder(authorizationRequest.getGrantType())
return new Builder()
.authorizationUri(authorizationRequest.getAuthorizationUri())
.clientId(authorizationRequest.getClientId())
.redirectUri(authorizationRequest.getRedirectUri())
@ -240,13 +256,32 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -240,13 +256,32 @@ public final class OAuth2AuthorizationRequest implements Serializable {
/**
* A builder for {@link OAuth2AuthorizationRequest}.
*/
public static final class Builder {
public static class Builder extends AbstractBuilder<OAuth2AuthorizationRequest, Builder> {
/**
* Builds a new {@link OAuth2AuthorizationRequest}.
* @return a {@link OAuth2AuthorizationRequest}
*/
@Override
public OAuth2AuthorizationRequest build() {
return new OAuth2AuthorizationRequest(this);
}
}
/**
* A builder for subclasses of {@link OAuth2AuthorizationRequest}.
*
* @param <T> the type of authorization request
* @param <B> the type of the builder
*/
protected abstract static class AbstractBuilder<T extends OAuth2AuthorizationRequest, B extends AbstractBuilder<T, B>> {
private String authorizationUri;
private AuthorizationGrantType authorizationGrantType;
private final AuthorizationGrantType authorizationGrantType = AuthorizationGrantType.AUTHORIZATION_CODE;
private OAuth2AuthorizationResponseType responseType;
private final OAuth2AuthorizationResponseType responseType = OAuth2AuthorizationResponseType.CODE;
private String clientId;
@ -269,12 +304,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -269,12 +304,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
private final DefaultUriBuilderFactory uriBuilderFactory;
private Builder(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
this.authorizationGrantType = authorizationGrantType;
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) {
this.responseType = OAuth2AuthorizationResponseType.CODE;
}
protected AbstractBuilder() {
this.uriBuilderFactory = new DefaultUriBuilderFactory();
// The supplied authorizationUri may contain encoded parameters
// so disable encoding in UriBuilder and instead apply encoding within this
@ -282,78 +312,85 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -282,78 +312,85 @@ public final class OAuth2AuthorizationRequest implements Serializable {
this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
}
@SuppressWarnings("unchecked")
protected final B getThis() {
// avoid unchecked casts in subclasses by using "getThis()" instead of "(B)
// this"
return (B) this;
}
/**
* Sets the uri for the authorization endpoint.
* @param authorizationUri the uri for the authorization endpoint
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder authorizationUri(String authorizationUri) {
public B authorizationUri(String authorizationUri) {
this.authorizationUri = authorizationUri;
return this;
return getThis();
}
/**
* Sets the client identifier.
* @param clientId the client identifier
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder clientId(String clientId) {
public B clientId(String clientId) {
this.clientId = clientId;
return this;
return getThis();
}
/**
* Sets the uri for the redirection endpoint.
* @param redirectUri the uri for the redirection endpoint
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder redirectUri(String redirectUri) {
public B redirectUri(String redirectUri) {
this.redirectUri = redirectUri;
return this;
return getThis();
}
/**
* Sets the scope(s).
* @param scope the scope(s)
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder scope(String... scope) {
public B scope(String... scope) {
if (scope != null && scope.length > 0) {
return scopes(new LinkedHashSet<>(Arrays.asList(scope)));
}
return this;
return getThis();
}
/**
* Sets the scope(s).
* @param scopes the scope(s)
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder scopes(Set<String> scopes) {
public B scopes(Set<String> scopes) {
this.scopes = scopes;
return this;
return getThis();
}
/**
* Sets the state.
* @param state the state
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder state(String state) {
public B state(String state) {
this.state = state;
return this;
return getThis();
}
/**
* Sets the additional parameter(s) used in the request.
* @param additionalParameters the additional parameter(s) used in the request
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
*/
public Builder additionalParameters(Map<String, Object> additionalParameters) {
public B additionalParameters(Map<String, Object> additionalParameters) {
if (!CollectionUtils.isEmpty(additionalParameters)) {
this.additionalParameters.putAll(additionalParameters);
}
return this;
return getThis();
}
/**
@ -361,52 +398,55 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -361,52 +398,55 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* allowing the ability to add, replace, or remove.
* @param additionalParametersConsumer a {@code Consumer} of the additional
* parameters
* @return the {@link AbstractBuilder}
* @since 5.3
*/
public Builder additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
public B additionalParameters(Consumer<Map<String, Object>> additionalParametersConsumer) {
if (additionalParametersConsumer != null) {
additionalParametersConsumer.accept(this.additionalParameters);
}
return this;
return getThis();
}
/**
* A {@code Consumer} to be provided access to all the parameters allowing the
* ability to add, replace, or remove.
* @param parametersConsumer a {@code Consumer} of all the parameters
* @return the {@link AbstractBuilder}
* @since 5.3
*/
public Builder parameters(Consumer<Map<String, Object>> parametersConsumer) {
public B parameters(Consumer<Map<String, Object>> parametersConsumer) {
if (parametersConsumer != null) {
this.parametersConsumer = parametersConsumer;
}
return this;
return getThis();
}
/**
* Sets the attributes associated to the request.
* @param attributes the attributes associated to the request
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
* @since 5.2
*/
public Builder attributes(Map<String, Object> attributes) {
public B attributes(Map<String, Object> attributes) {
if (!CollectionUtils.isEmpty(attributes)) {
this.attributes.putAll(attributes);
}
return this;
return getThis();
}
/**
* A {@code Consumer} to be provided access to the attribute(s) allowing the
* ability to add, replace, or remove.
* @param attributesConsumer a {@code Consumer} of the attribute(s)
* @return the {@link AbstractBuilder}
* @since 5.3
*/
public Builder attributes(Consumer<Map<String, Object>> attributesConsumer) {
public B attributes(Consumer<Map<String, Object>> attributesConsumer) {
if (attributesConsumer != null) {
attributesConsumer.accept(this.attributes);
}
return this;
return getThis();
}
/**
@ -418,12 +458,12 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -418,12 +458,12 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* {@code application/x-www-form-urlencoded} MIME format.
* @param authorizationRequestUri the {@code URI} string representation of the
* OAuth 2.0 Authorization Request
* @return the {@link Builder}
* @return the {@link AbstractBuilder}
* @since 5.1
*/
public Builder authorizationRequestUri(String authorizationRequestUri) {
public B authorizationRequestUri(String authorizationRequestUri) {
this.authorizationRequestUri = authorizationRequestUri;
return this;
return getThis();
}
/**
@ -431,37 +471,17 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -431,37 +471,17 @@ public final class OAuth2AuthorizationRequest implements Serializable {
* OAuth 2.0 Authorization Request allowing for further customizations.
* @param authorizationRequestUriFunction a {@code Function} to be provided a
* {@code UriBuilder} representation of the OAuth 2.0 Authorization Request
* @return the {@link AbstractBuilder}
* @since 5.3
*/
public Builder authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
public B authorizationRequestUri(Function<UriBuilder, URI> authorizationRequestUriFunction) {
if (authorizationRequestUriFunction != null) {
this.authorizationRequestUriFunction = authorizationRequestUriFunction;
}
return this;
return getThis();
}
/**
* Builds a new {@link OAuth2AuthorizationRequest}.
* @return a {@link OAuth2AuthorizationRequest}
*/
public OAuth2AuthorizationRequest build() {
Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty");
Assert.hasText(this.clientId, "clientId cannot be empty");
OAuth2AuthorizationRequest authorizationRequest = new OAuth2AuthorizationRequest();
authorizationRequest.authorizationUri = this.authorizationUri;
authorizationRequest.authorizationGrantType = this.authorizationGrantType;
authorizationRequest.responseType = this.responseType;
authorizationRequest.clientId = this.clientId;
authorizationRequest.redirectUri = this.redirectUri;
authorizationRequest.state = this.state;
authorizationRequest.scopes = Collections.unmodifiableSet(
CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes));
authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters);
authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes);
authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri)
? this.authorizationRequestUri : this.buildAuthorizationRequestUri();
return authorizationRequest;
}
public abstract T build();
private String buildAuthorizationRequestUri() {
Map<String, Object> parameters = getParameters(); // Not encoded
@ -486,7 +506,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { @@ -486,7 +506,7 @@ public final class OAuth2AuthorizationRequest implements Serializable {
return this.authorizationRequestUriFunction.apply(uriBuilder).toString();
}
private Map<String, Object> getParameters() {
protected Map<String, Object> getParameters() {
Map<String, Object> parameters = new LinkedHashMap<>();
parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue());
parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId);

36
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java

@ -388,4 +388,40 @@ public class OAuth2AuthorizationRequestTests { @@ -388,4 +388,40 @@ public class OAuth2AuthorizationRequestTests {
assertThat(authorizationRequest1HashCode).isEqualTo(authorizationRequest2HashCode);
}
@Test
public void buildWhenExtendedTypeAndAllValuesProvidedThenAllValuesAreSet() {
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");
Map<String, Object> attributes = new HashMap<>();
attributes.put("attribute1", "value1");
attributes.put("attribute2", "value2");
// @formatter:off
TestOidcAuthorizationRequest oidcAuthorizationRequest = TestOidcAuthorizationRequest.builder()
.authorizationUri(AUTHORIZATION_URI)
.clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI)
.scopes(SCOPES)
.state(STATE)
.additionalParameters(additionalParameters)
.attributes(attributes)
.nonce("nonce1234")
.build();
// @formatter:on
assertThat(oidcAuthorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI);
assertThat(oidcAuthorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
assertThat(oidcAuthorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE);
assertThat(oidcAuthorizationRequest.getClientId()).isEqualTo(CLIENT_ID);
assertThat(oidcAuthorizationRequest.getRedirectUri()).isEqualTo(REDIRECT_URI);
assertThat(oidcAuthorizationRequest.getScopes()).isEqualTo(SCOPES);
assertThat(oidcAuthorizationRequest.getState()).isEqualTo(STATE);
assertThat(oidcAuthorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters);
assertThat(oidcAuthorizationRequest.getAttributes()).isEqualTo(attributes);
assertThat(oidcAuthorizationRequest.getNonce()).isEqualTo("nonce1234");
assertThat(oidcAuthorizationRequest.getAuthorizationRequestUri())
.isEqualTo("https://provider.com/oauth2/authorize?" + "response_type=code&client_id=client-id&"
+ "scope=scope1%20scope2&state=state&"
+ "redirect_uri=https://example.com&param1=value1&param2=value2&nonce=nonce1234");
}
}

68
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java

@ -0,0 +1,68 @@ @@ -0,0 +1,68 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.core.endpoint;
import java.util.Map;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
/**
* @author Joe Grandja
*/
public class TestOidcAuthorizationRequest extends OAuth2AuthorizationRequest {
private final String nonce;
protected TestOidcAuthorizationRequest(Builder builder) {
super(builder);
this.nonce = builder.nonce;
}
public String getNonce() {
return this.nonce;
}
public static Builder builder() {
return new Builder();
}
public static class Builder extends AbstractBuilder<TestOidcAuthorizationRequest, Builder> {
private String nonce;
public Builder nonce(String nonce) {
this.nonce = nonce;
return this;
}
@Override
public TestOidcAuthorizationRequest build() {
return new TestOidcAuthorizationRequest(this);
}
@Override
protected Map<String, Object> getParameters() {
Map<String, Object> parameters = super.getParameters();
if (this.nonce != null) {
parameters.put(OidcParameterNames.NONCE, this.nonce);
}
return parameters;
}
}
}
Loading…
Cancel
Save