diff --git a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java index b7ef307509..b9b48f443f 100644 --- a/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java +++ b/oauth2/oauth2-core/src/main/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequest.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.core.endpoint; +import java.io.Serial; import java.io.Serializable; import java.net.URI; import java.nio.charset.StandardCharsets; @@ -51,31 +52,46 @@ import org.springframework.web.util.UriUtils; * "https://tools.ietf.org/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Code * Grant Request */ -public final class OAuth2AuthorizationRequest implements Serializable { +public class OAuth2AuthorizationRequest implements Serializable { + @Serial private static final long serialVersionUID = 620L; - private String authorizationUri; + private final String authorizationUri; - private AuthorizationGrantType authorizationGrantType; + private final AuthorizationGrantType authorizationGrantType; - private OAuth2AuthorizationResponseType responseType; + private final OAuth2AuthorizationResponseType responseType; - private String clientId; + private final String clientId; - private String redirectUri; + private final String redirectUri; - private Set scopes; + private final Set scopes; - private String state; + private final String state; - private Map additionalParameters; + private final Map additionalParameters; - private String authorizationRequestUri; + private final String authorizationRequestUri; - private Map attributes; + private final Map attributes; - private OAuth2AuthorizationRequest() { + protected OAuth2AuthorizationRequest(AbstractBuilder builder) { + Assert.hasText(builder.authorizationUri, "authorizationUri cannot be empty"); + Assert.hasText(builder.clientId, "clientId cannot be empty"); + this.authorizationUri = builder.authorizationUri; + this.authorizationGrantType = builder.authorizationGrantType; + this.responseType = builder.responseType; + this.clientId = builder.clientId; + this.redirectUri = builder.redirectUri; + this.scopes = Collections.unmodifiableSet( + CollectionUtils.isEmpty(builder.scopes) ? Collections.emptySet() : new LinkedHashSet<>(builder.scopes)); + this.state = builder.state; + this.additionalParameters = Collections.unmodifiableMap(builder.additionalParameters); + this.authorizationRequestUri = StringUtils.hasText(builder.authorizationRequestUri) + ? builder.authorizationRequestUri : builder.buildAuthorizationRequestUri(); + this.attributes = Collections.unmodifiableMap(builder.attributes); } /** @@ -185,7 +201,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { * @return the {@link Builder} */ public static Builder authorizationCode() { - return new Builder(AuthorizationGrantType.AUTHORIZATION_CODE); + return new Builder(); } @Override @@ -226,7 +242,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { public static Builder from(OAuth2AuthorizationRequest authorizationRequest) { Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); // @formatter:off - return new Builder(authorizationRequest.getGrantType()) + return new Builder() .authorizationUri(authorizationRequest.getAuthorizationUri()) .clientId(authorizationRequest.getClientId()) .redirectUri(authorizationRequest.getRedirectUri()) @@ -240,13 +256,32 @@ public final class OAuth2AuthorizationRequest implements Serializable { /** * A builder for {@link OAuth2AuthorizationRequest}. */ - public static final class Builder { + public static class Builder extends AbstractBuilder { + + /** + * Builds a new {@link OAuth2AuthorizationRequest}. + * @return a {@link OAuth2AuthorizationRequest} + */ + @Override + public OAuth2AuthorizationRequest build() { + return new OAuth2AuthorizationRequest(this); + } + + } + + /** + * A builder for subclasses of {@link OAuth2AuthorizationRequest}. + * + * @param the type of authorization request + * @param the type of the builder + */ + protected abstract static class AbstractBuilder> { private String authorizationUri; - private AuthorizationGrantType authorizationGrantType; + private final AuthorizationGrantType authorizationGrantType = AuthorizationGrantType.AUTHORIZATION_CODE; - private OAuth2AuthorizationResponseType responseType; + private final OAuth2AuthorizationResponseType responseType = OAuth2AuthorizationResponseType.CODE; private String clientId; @@ -269,12 +304,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { private final DefaultUriBuilderFactory uriBuilderFactory; - private Builder(AuthorizationGrantType authorizationGrantType) { - Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null"); - this.authorizationGrantType = authorizationGrantType; - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { - this.responseType = OAuth2AuthorizationResponseType.CODE; - } + protected AbstractBuilder() { this.uriBuilderFactory = new DefaultUriBuilderFactory(); // The supplied authorizationUri may contain encoded parameters // so disable encoding in UriBuilder and instead apply encoding within this @@ -282,78 +312,85 @@ public final class OAuth2AuthorizationRequest implements Serializable { this.uriBuilderFactory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE); } + @SuppressWarnings("unchecked") + protected final B getThis() { + // avoid unchecked casts in subclasses by using "getThis()" instead of "(B) + // this" + return (B) this; + } + /** * Sets the uri for the authorization endpoint. * @param authorizationUri the uri for the authorization endpoint - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder authorizationUri(String authorizationUri) { + public B authorizationUri(String authorizationUri) { this.authorizationUri = authorizationUri; - return this; + return getThis(); } /** * Sets the client identifier. * @param clientId the client identifier - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder clientId(String clientId) { + public B clientId(String clientId) { this.clientId = clientId; - return this; + return getThis(); } /** * Sets the uri for the redirection endpoint. * @param redirectUri the uri for the redirection endpoint - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder redirectUri(String redirectUri) { + public B redirectUri(String redirectUri) { this.redirectUri = redirectUri; - return this; + return getThis(); } /** * Sets the scope(s). * @param scope the scope(s) - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder scope(String... scope) { + public B scope(String... scope) { if (scope != null && scope.length > 0) { return scopes(new LinkedHashSet<>(Arrays.asList(scope))); } - return this; + return getThis(); } /** * Sets the scope(s). * @param scopes the scope(s) - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder scopes(Set scopes) { + public B scopes(Set scopes) { this.scopes = scopes; - return this; + return getThis(); } /** * Sets the state. * @param state the state - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder state(String state) { + public B state(String state) { this.state = state; - return this; + return getThis(); } /** * Sets the additional parameter(s) used in the request. * @param additionalParameters the additional parameter(s) used in the request - * @return the {@link Builder} + * @return the {@link AbstractBuilder} */ - public Builder additionalParameters(Map additionalParameters) { + public B additionalParameters(Map additionalParameters) { if (!CollectionUtils.isEmpty(additionalParameters)) { this.additionalParameters.putAll(additionalParameters); } - return this; + return getThis(); } /** @@ -361,52 +398,55 @@ public final class OAuth2AuthorizationRequest implements Serializable { * allowing the ability to add, replace, or remove. * @param additionalParametersConsumer a {@code Consumer} of the additional * parameters + * @return the {@link AbstractBuilder} * @since 5.3 */ - public Builder additionalParameters(Consumer> additionalParametersConsumer) { + public B additionalParameters(Consumer> additionalParametersConsumer) { if (additionalParametersConsumer != null) { additionalParametersConsumer.accept(this.additionalParameters); } - return this; + return getThis(); } /** * A {@code Consumer} to be provided access to all the parameters allowing the * ability to add, replace, or remove. * @param parametersConsumer a {@code Consumer} of all the parameters + * @return the {@link AbstractBuilder} * @since 5.3 */ - public Builder parameters(Consumer> parametersConsumer) { + public B parameters(Consumer> parametersConsumer) { if (parametersConsumer != null) { this.parametersConsumer = parametersConsumer; } - return this; + return getThis(); } /** * Sets the attributes associated to the request. * @param attributes the attributes associated to the request - * @return the {@link Builder} + * @return the {@link AbstractBuilder} * @since 5.2 */ - public Builder attributes(Map attributes) { + public B attributes(Map attributes) { if (!CollectionUtils.isEmpty(attributes)) { this.attributes.putAll(attributes); } - return this; + return getThis(); } /** * A {@code Consumer} to be provided access to the attribute(s) allowing the * ability to add, replace, or remove. * @param attributesConsumer a {@code Consumer} of the attribute(s) + * @return the {@link AbstractBuilder} * @since 5.3 */ - public Builder attributes(Consumer> attributesConsumer) { + public B attributes(Consumer> attributesConsumer) { if (attributesConsumer != null) { attributesConsumer.accept(this.attributes); } - return this; + return getThis(); } /** @@ -418,12 +458,12 @@ public final class OAuth2AuthorizationRequest implements Serializable { * {@code application/x-www-form-urlencoded} MIME format. * @param authorizationRequestUri the {@code URI} string representation of the * OAuth 2.0 Authorization Request - * @return the {@link Builder} + * @return the {@link AbstractBuilder} * @since 5.1 */ - public Builder authorizationRequestUri(String authorizationRequestUri) { + public B authorizationRequestUri(String authorizationRequestUri) { this.authorizationRequestUri = authorizationRequestUri; - return this; + return getThis(); } /** @@ -431,37 +471,17 @@ public final class OAuth2AuthorizationRequest implements Serializable { * OAuth 2.0 Authorization Request allowing for further customizations. * @param authorizationRequestUriFunction a {@code Function} to be provided a * {@code UriBuilder} representation of the OAuth 2.0 Authorization Request + * @return the {@link AbstractBuilder} * @since 5.3 */ - public Builder authorizationRequestUri(Function authorizationRequestUriFunction) { + public B authorizationRequestUri(Function authorizationRequestUriFunction) { if (authorizationRequestUriFunction != null) { this.authorizationRequestUriFunction = authorizationRequestUriFunction; } - return this; + return getThis(); } - /** - * Builds a new {@link OAuth2AuthorizationRequest}. - * @return a {@link OAuth2AuthorizationRequest} - */ - public OAuth2AuthorizationRequest build() { - Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty"); - Assert.hasText(this.clientId, "clientId cannot be empty"); - OAuth2AuthorizationRequest authorizationRequest = new OAuth2AuthorizationRequest(); - authorizationRequest.authorizationUri = this.authorizationUri; - authorizationRequest.authorizationGrantType = this.authorizationGrantType; - authorizationRequest.responseType = this.responseType; - authorizationRequest.clientId = this.clientId; - authorizationRequest.redirectUri = this.redirectUri; - authorizationRequest.state = this.state; - authorizationRequest.scopes = Collections.unmodifiableSet( - CollectionUtils.isEmpty(this.scopes) ? Collections.emptySet() : new LinkedHashSet<>(this.scopes)); - authorizationRequest.additionalParameters = Collections.unmodifiableMap(this.additionalParameters); - authorizationRequest.attributes = Collections.unmodifiableMap(this.attributes); - authorizationRequest.authorizationRequestUri = StringUtils.hasText(this.authorizationRequestUri) - ? this.authorizationRequestUri : this.buildAuthorizationRequestUri(); - return authorizationRequest; - } + public abstract T build(); private String buildAuthorizationRequestUri() { Map parameters = getParameters(); // Not encoded @@ -486,7 +506,7 @@ public final class OAuth2AuthorizationRequest implements Serializable { return this.authorizationRequestUriFunction.apply(uriBuilder).toString(); } - private Map getParameters() { + protected Map getParameters() { Map parameters = new LinkedHashMap<>(); parameters.put(OAuth2ParameterNames.RESPONSE_TYPE, this.responseType.getValue()); parameters.put(OAuth2ParameterNames.CLIENT_ID, this.clientId); diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java index 272f2daf5c..6df6b4ca8c 100644 --- a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/OAuth2AuthorizationRequestTests.java @@ -388,4 +388,40 @@ public class OAuth2AuthorizationRequestTests { assertThat(authorizationRequest1HashCode).isEqualTo(authorizationRequest2HashCode); } + @Test + public void buildWhenExtendedTypeAndAllValuesProvidedThenAllValuesAreSet() { + Map additionalParameters = new HashMap<>(); + additionalParameters.put("param1", "value1"); + additionalParameters.put("param2", "value2"); + Map attributes = new HashMap<>(); + attributes.put("attribute1", "value1"); + attributes.put("attribute2", "value2"); + // @formatter:off + TestOidcAuthorizationRequest oidcAuthorizationRequest = TestOidcAuthorizationRequest.builder() + .authorizationUri(AUTHORIZATION_URI) + .clientId(CLIENT_ID) + .redirectUri(REDIRECT_URI) + .scopes(SCOPES) + .state(STATE) + .additionalParameters(additionalParameters) + .attributes(attributes) + .nonce("nonce1234") + .build(); + // @formatter:on + assertThat(oidcAuthorizationRequest.getAuthorizationUri()).isEqualTo(AUTHORIZATION_URI); + assertThat(oidcAuthorizationRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE); + assertThat(oidcAuthorizationRequest.getResponseType()).isEqualTo(OAuth2AuthorizationResponseType.CODE); + assertThat(oidcAuthorizationRequest.getClientId()).isEqualTo(CLIENT_ID); + assertThat(oidcAuthorizationRequest.getRedirectUri()).isEqualTo(REDIRECT_URI); + assertThat(oidcAuthorizationRequest.getScopes()).isEqualTo(SCOPES); + assertThat(oidcAuthorizationRequest.getState()).isEqualTo(STATE); + assertThat(oidcAuthorizationRequest.getAdditionalParameters()).isEqualTo(additionalParameters); + assertThat(oidcAuthorizationRequest.getAttributes()).isEqualTo(attributes); + assertThat(oidcAuthorizationRequest.getNonce()).isEqualTo("nonce1234"); + assertThat(oidcAuthorizationRequest.getAuthorizationRequestUri()) + .isEqualTo("https://provider.com/oauth2/authorize?" + "response_type=code&client_id=client-id&" + + "scope=scope1%20scope2&state=state&" + + "redirect_uri=https://example.com¶m1=value1¶m2=value2&nonce=nonce1234"); + } + } diff --git a/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java new file mode 100644 index 0000000000..54b14a2231 --- /dev/null +++ b/oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/endpoint/TestOidcAuthorizationRequest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2004-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.core.endpoint; + +import java.util.Map; + +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; + +/** + * @author Joe Grandja + */ +public class TestOidcAuthorizationRequest extends OAuth2AuthorizationRequest { + + private final String nonce; + + protected TestOidcAuthorizationRequest(Builder builder) { + super(builder); + this.nonce = builder.nonce; + } + + public String getNonce() { + return this.nonce; + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder extends AbstractBuilder { + + private String nonce; + + public Builder nonce(String nonce) { + this.nonce = nonce; + return this; + } + + @Override + public TestOidcAuthorizationRequest build() { + return new TestOidcAuthorizationRequest(this); + } + + @Override + protected Map getParameters() { + Map parameters = super.getParameters(); + if (this.nonce != null) { + parameters.put(OidcParameterNames.NONCE, this.nonce); + } + return parameters; + } + + } + +}