Browse Source

Merge 48add5d641 into 8f30567b83

pull/18888/merge
Evgeniy Cheban 1 week ago committed by GitHub
parent
commit
57c6b035bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 47
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java
  2. 48
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java
  3. 39
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java
  4. 41
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

47
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@ -118,8 +118,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -118,8 +118,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_USER"));
private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
.mapNotNull(SecurityContext::getAuthentication);
// @formatter:off
private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono
@ -144,6 +143,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -144,6 +143,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository();
private PrincipalResolver principalResolver = (request) -> this.currentAuthenticationMono;
/**
* Constructs a {@code ServerOAuth2AuthorizedClientExchangeFilterFunction} using the
* provided parameters.
@ -326,6 +327,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -326,6 +327,15 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return this.principalResolver.resolve(request)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN)
.flatMap((authentication) -> doFilter(request, next)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)));
// @formatter:on
}
private Mono<ClientResponse> doFilter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return authorizedClient(request)
.map((authorizedClient) -> bearer(request, authorizedClient))
@ -477,6 +487,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -477,6 +487,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
this.serverSecurityContextRepository = serverSecurityContextRepository;
}
/**
* Sets the strategy for resolving a {@link Mono} of the {@link Authentication
* principal} from an intercepted request.
* @param principalResolver the strategy for resolving a {@link Mono} of the
* {@link Authentication principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}
@FunctionalInterface
private interface ClientResponseHandler {
@ -484,6 +506,27 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @@ -484,6 +506,27 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
}
/**
* A strategy for resolving a {@link Mono} of the {@link Authentication principal}
* from an intercepted request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {
/**
* Resolve a {@link Mono} of the {@link Authentication principal} from the current
* request, which is used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
Mono<Authentication> resolve(ClientRequest request);
}
/**
* Forwards authentication and authorization failures to a
* {@link ReactiveOAuth2AuthorizationFailureHandler}.

48
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@ -27,6 +27,7 @@ import java.util.stream.Stream; @@ -27,6 +27,7 @@ import java.util.stream.Stream;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.jspecify.annotations.Nullable;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;
@ -122,6 +123,7 @@ import org.springframework.web.reactive.function.client.WebClientResponseExcepti @@ -122,6 +123,7 @@ import org.springframework.web.reactive.function.client.WebClientResponseExcepti
* @author Rob Winch
* @author Joe Grandja
* @author Roman Matiushchenko
* @author Evgeniy Cheban
* @since 5.1
* @see OAuth2AuthorizedClientManager
* @see DefaultOAuth2AuthorizedClientManager
@ -151,6 +153,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -151,6 +153,13 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
.getContextHolderStrategy();
/*
* For consistency, the default implementation resolves a principal from request
* attributes. Request attributes are populated from Reactor context which is enriched
* in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber
*/
private PrincipalResolver principalResolver = (request) -> getAuthentication(request.attributes());
private OAuth2AuthorizedClientManager authorizedClientManager;
private boolean defaultOAuth2AuthorizedClient;
@ -372,6 +381,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -372,6 +381,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
this.clientResponseHandler = new AuthorizationFailureForwarder(authorizationFailureHandler);
}
/**
* Sets the strategy for resolving a {@link Authentication principal} from an
* intercepted request.
* @param principalResolver the strategy for resolving a {@link Authentication
* principal}
* @since 7.1
*/
public void setPrincipalResolver(PrincipalResolver principalResolver) {
Assert.notNull(principalResolver, "principalResolver cannot be null");
this.principalResolver = principalResolver;
}
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
@ -459,7 +480,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -459,7 +480,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
if (clientRegistrationId == null) {
clientRegistrationId = this.defaultClientRegistrationId;
}
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient
&& authentication instanceof OAuth2AuthenticationToken) {
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
@ -472,7 +493,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -472,7 +493,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return Mono.empty();
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = ANONYMOUS_AUTHENTICATION;
}
@ -495,7 +516,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -495,7 +516,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
return Mono.just(authorizedClient);
}
Map<String, Object> attrs = request.attributes();
Authentication authentication = getAuthentication(attrs);
Authentication authentication = this.principalResolver.resolve(request);
if (authentication == null) {
authentication = createAuthentication(authorizedClient.getPrincipalName());
}
@ -567,6 +588,27 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement @@ -567,6 +588,27 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
};
}
/**
* A strategy for resolving a {@link Authentication principal} from an intercepted
* request.
*
* @since 7.1
*/
@FunctionalInterface
public interface PrincipalResolver {
/**
* Resolve a {@link Authentication principal} from the current request, which is
* used to obtain an {@link OAuth2AuthorizedClient}.
* @param request the intercepted request, containing HTTP method, URI, headers,
* and request attributes
* @return the {@link Mono} of the {@link Authentication principal} to be used for
* resolving an {@link OAuth2AuthorizedClient}
*/
@Nullable Authentication resolve(ClientRequest request);
}
@FunctionalInterface
private interface ClientResponseHandler {

39
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@ -218,6 +218,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -218,6 +218,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.setServerSecurityContextRepository(null));
}
@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
@ -791,6 +798,38 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -791,6 +798,38 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request0)).isEmpty();
}
@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"),
Collections.singletonMap("user", "rob"), "user");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange))
.willReturn(Mono.just(authorizedClient));
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> Mono.just(authentication));
this.function.filter(clientRequest, this.exchange)
.contextWrite(ReactiveSecurityContextHolder.withAuthentication(initialAuthentication))
.contextWrite(serverWebExchange())
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0);
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request0)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, this.serverWebExchange);
}
@Test
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();

41
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@ -125,6 +125,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat @@ -125,6 +125,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
/**
* @author Rob Winch
* @author Evgeniy Cheban
* @since 5.1
*/
@ExtendWith(MockitoExtension.class)
@ -217,6 +218,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -217,6 +218,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null));
}
@Test
public void setPrincipalResolverWhenResolverIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientManager)
.setPrincipalResolver(null));
}
@Test
public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() {
Map<String, Object> attrs = getDefaultRequestAttributes();
@ -620,6 +628,39 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @@ -620,6 +628,39 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(getBody(request)).isEmpty();
}
@Test
public void filterWhenClientRegistrationIdFromAuthenticationAndCustomPrincipalResolverThenAuthorizedClientResolved() {
this.function.setDefaultOAuth2AuthorizedClient(true);
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
OAuth2User user = mock(OAuth2User.class);
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
OAuth2AuthenticationToken initialAuthentication = new OAuth2AuthenticationToken(user, authorities,
"initial-registration-id");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, authorities,
this.registration.getRegistrationId());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken);
given(this.authorizedClientRepository.loadAuthorizedClient(this.registration.getRegistrationId(),
initialAuthentication, servletRequest))
.willReturn(authorizedClient);
final ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.build();
this.function.setPrincipalResolver((request) -> authentication);
this.function.filter(clientRequest, this.exchange)
.contextWrite(context(servletRequest, servletResponse, initialAuthentication))
.block();
List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1);
ClientRequest request = requests.get(0);
assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
assertThat(request.method()).isEqualTo(HttpMethod.GET);
assertThat(getBody(request)).isEmpty();
verify(this.authorizedClientRepository).loadAuthorizedClient(this.registration.getRegistrationId(),
authentication, servletRequest);
}
@Test
public void filterWhenUnauthorizedThenInvokeFailureHandler() {
assertHttpStatusInvokesFailureHandler(HttpStatus.UNAUTHORIZED, OAuth2ErrorCodes.INVALID_TOKEN);

Loading…
Cancel
Save