diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 6a60fb977a..b14ad35dc6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -118,8 +118,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER")); private final Mono currentAuthenticationMono = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + .mapNotNull(SecurityContext::getAuthentication); // @formatter:off private final Mono clientRegistrationIdMono = this.currentAuthenticationMono @@ -144,6 +143,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); + private PrincipalResolver principalResolver = (request) -> this.currentAuthenticationMono; + /** * Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the * provided parameters. @@ -326,6 +327,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono filter(ClientRequest request, ExchangeFunction next) { + // @formatter:off + return this.principalResolver.resolve(request) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN) + .flatMap((authentication) -> doFilter(request, next) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))); + // @formatter:on + } + + private Mono doFilter(ClientRequest request, ExchangeFunction next) { // @formatter:off return authorizedClient(request) .map((authorizedClient) -> bearer(request, authorizedClient)) @@ -477,6 +487,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements this.serverSecurityContextRepository = serverSecurityContextRepository; } + /** + * Sets the strategy for resolving a {@link Mono} of the {@link Authentication + * principal} from an intercepted request. + * @param principalResolver the strategy for resolving a {@link Mono} of the + * {@link Authentication principal} + * @since 7.1 + */ + public void setPrincipalResolver(PrincipalResolver principalResolver) { + Assert.notNull(principalResolver, "principalResolver cannot be null"); + this.principalResolver = principalResolver; + } + @FunctionalInterface private interface ClientResponseHandler { @@ -484,6 +506,27 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements } + /** + * A strategy for resolving a {@link Mono} of the {@link Authentication principal} + * from an intercepted request. + * + * @since 7.1 + */ + @FunctionalInterface + public interface PrincipalResolver { + + /** + * Resolve a {@link Mono} of the {@link Authentication principal} from the current + * request, which is used to obtain an {@link OAuth2AuthorizedClient}. + * @param request the intercepted request, containing HTTP method, URI, headers, + * and request attributes + * @return the {@link Mono} of the {@link Authentication principal} to be used for + * resolving an {@link OAuth2AuthorizedClient} + */ + Mono resolve(ClientRequest request); + + } + /** * Forwards authentication and authorization failures to a * {@link ReactiveOAuth2AuthorizationFailureHandler}. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 6f7f8d3a12..6ad3eea021 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -27,6 +27,7 @@ import java.util.stream.Stream; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import org.jspecify.annotations.Nullable; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; @@ -122,6 +123,7 @@ import org.springframework.web.reactive.function.client.WebClientResponseExcepti * @author Rob Winch * @author Joe Grandja * @author Roman Matiushchenko + * @author Evgeniy Cheban * @since 5.1 * @see OAuth2AuthorizedClientManager * @see DefaultOAuth2AuthorizedClientManager @@ -151,6 +153,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder .getContextHolderStrategy(); + /* + * For consistency, the default implementation resolves a principal from request + * attributes. Request attributes are populated from Reactor context which is enriched + * in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber + */ + private PrincipalResolver principalResolver = (request) -> getAuthentication(request.attributes()); + private OAuth2AuthorizedClientManager authorizedClientManager; private boolean defaultOAuth2AuthorizedClient; @@ -372,6 +381,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler); } + /** + * Sets the strategy for resolving a {@link Authentication principal} from an + * intercepted request. + * @param principalResolver the strategy for resolving a {@link Authentication + * principal} + * @since 7.1 + */ + public void setPrincipalResolver(PrincipalResolver principalResolver) { + Assert.notNull(principalResolver, "principalResolver cannot be null"); + this.principalResolver = principalResolver; + } + @Override public Mono filter(ClientRequest request, ExchangeFunction next) { // @formatter:off @@ -459,7 +480,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement if (clientRegistrationId == null) { clientRegistrationId = this.defaultClientRegistrationId; } - Authentication authentication = getAuthentication(attrs); + Authentication authentication = this.principalResolver.resolve(request); if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) { clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId(); @@ -472,7 +493,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement return Mono.empty(); } Map attrs = request.attributes(); - Authentication authentication = getAuthentication(attrs); + Authentication authentication = this.principalResolver.resolve(request); if (authentication == null) { authentication = ANONYMOUS_AUTHENTICATION; } @@ -495,7 +516,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement return Mono.just(authorizedClient); } Map attrs = request.attributes(); - Authentication authentication = getAuthentication(attrs); + Authentication authentication = this.principalResolver.resolve(request); if (authentication == null) { authentication = createAuthentication(authorizedClient.getPrincipalName()); } @@ -567,6 +588,27 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement }; } + /** + * A strategy for resolving a {@link Authentication principal} from an intercepted + * request. + * + * @since 7.1 + */ + @FunctionalInterface + public interface PrincipalResolver { + + /** + * Resolve a {@link Authentication principal} from the current request, which is + * used to obtain an {@link OAuth2AuthorizedClient}. + * @param request the intercepted request, containing HTTP method, URI, headers, + * and request attributes + * @return the {@link Mono} of the {@link Authentication principal} to be used for + * resolving an {@link OAuth2AuthorizedClient} + */ + @Nullable Authentication resolve(ClientRequest request); + + } + @FunctionalInterface private interface ClientResponseHandler { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index e2046e4af1..eb34f31546 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -218,6 +218,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { .setServerSecurityContextRepository(null)); } + @Test + public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager) + .setPrincipalResolver(null)); + } + @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); @@ -791,6 +798,38 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request0)).isEmpty(); } + @Test + public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() { + this.function.setDefaultOAuth2AuthorizedClient(true); + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), + Collections.singletonMap("user", "rob"), "user"); + OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), + "initial-registration-id"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), + this.registration.getRegistrationId()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(), + authentication, this.serverWebExchange)) + .willReturn(Mono.just(authorizedClient)); + final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .build(); + this.function.setPrincipalResolver((request) -> Mono.just(authentication)); + this.function.filter(clientRequest, this.exchange) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(initialAuthentication)) + .contextWrite(serverWebExchange()) + .block(); + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request0 = requests.get(0); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); + verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(), + authentication, this.serverWebExchange); + } + @Test public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index bdb76b7335..423705420c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -125,6 +125,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat /** * @author Rob Winch + * @author Evgeniy Cheban * @since 5.1 */ @ExtendWith(MockitoExtension.class) @@ -217,6 +218,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { .isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null)); } + @Test + public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager) + .setPrincipalResolver(null)); + } + @Test public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { Map attrs = getDefaultRequestAttributes(); @@ -620,6 +628,39 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { assertThat(getBody(request)).isEmpty(); } + @Test + public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() { + this.function.setDefaultOAuth2AuthorizedClient(true); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + OAuth2User user = mock(OAuth2User.class); + List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); + OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities, + "initial-registration-id"); + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities, + this.registration.getRegistrationId()); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", + this.accessToken); + given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(), + initialAuthentication, servletRequest)) + .willReturn(authorizedClient); + final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) + .build(); + this.function.setPrincipalResolver((request) -> authentication); + this.function.filter(clientRequest, this.exchange) + .contextWrite(context(servletRequest, servletResponse, initialAuthentication)) + .block(); + List requests = this.exchange.getRequests(); + assertThat(requests).hasSize(1); + ClientRequest request = requests.get(0); + assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); + assertThat(request.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request)).isEmpty(); + verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(), + authentication, servletRequest); + } + @Test public void filterWhenUnauthorizedThenInvokeFailureHandler() { assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);