Browse Source

DefaultReactiveOAuth2AuthorizedClientManager defaults ServerWebExchange

Fixes gh-7390
pull/7427/head
Joe Grandja 6 years ago
parent
commit
dcdeab596d
  1. 62
      oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java
  2. 100
      oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java

62
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManager.java

@ -70,35 +70,52 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
String clientRegistrationId = authorizeRequest.getClientRegistrationId(); String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal(); Authentication principal = authorizeRequest.getPrincipal();
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
Assert.notNull(serverWebExchange, "serverWebExchange cannot be null");
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.switchIfEmpty(Mono.defer(() -> .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
.flatMap(authorizedClient -> { .flatMap(authorizedClient -> {
// Re-authorize // Re-authorize
return authorizationContext(authorizeRequest, authorizedClient) return authorizationContext(authorizeRequest, authorizedClient)
.flatMap(this.authorizedClientProvider::authorize) .flatMap(this.authorizedClientProvider::authorize)
.doOnNext(reauthorizedClient -> .flatMap(reauthorizedClient -> saveAuthorizedClient(reauthorizedClient, principal, serverWebExchange))
this.authorizedClientRepository.saveAuthorizedClient(
reauthorizedClient, principal, serverWebExchange))
// Default to the existing authorizedClient if the client was not re-authorized // Default to the existing authorizedClient if the client was not re-authorized
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ? .defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
authorizeRequest.getAuthorizedClient() : authorizedClient); authorizeRequest.getAuthorizedClient() : authorizedClient);
}) })
.switchIfEmpty(Mono.defer(() -> .switchIfEmpty(Mono.deferWithContext(context ->
// Authorize // Authorize
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException( .switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
"Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) "Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
.flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration)) .flatMap(clientRegistration -> authorizationContext(authorizeRequest, clientRegistration))
.flatMap(this.authorizedClientProvider::authorize) .flatMap(this.authorizedClientProvider::authorize)
.doOnNext(authorizedClient -> .flatMap(authorizedClient -> saveAuthorizedClient(authorizedClient, principal, serverWebExchange))
this.authorizedClientRepository.saveAuthorizedClient( .subscriberContext(context)
authorizedClient, principal, serverWebExchange)) )
)); );
}
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, ServerWebExchange serverWebExchange) {
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.flatMap(exchange -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange));
}
private Mono<OAuth2AuthorizedClient> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal, ServerWebExchange serverWebExchange) {
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
.map(exchange -> {
this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, exchange);
return authorizedClient;
})
.defaultIfEmpty(authorizedClient);
}
private static Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
} }
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest, private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
@ -158,15 +175,20 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
@Override @Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) { public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
Map<String, Object> contextAttributes = Collections.emptyMap();
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); return Mono.justOrEmpty(serverWebExchange)
if (StringUtils.hasText(scope)) { .switchIfEmpty(Mono.defer(() -> currentServerWebExchange()))
contextAttributes = new HashMap<>(); .flatMap(exchange -> {
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, Map<String, Object> contextAttributes = Collections.emptyMap();
StringUtils.delimitedListToStringArray(scope, " ")); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
} if (StringUtils.hasText(scope)) {
return Mono.just(contextAttributes); contextAttributes = new HashMap<>();
contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME,
StringUtils.delimitedListToStringArray(scope, " "));
}
return Mono.just(contextAttributes);
})
.defaultIfEmpty(Collections.emptyMap());
} }
} }
} }

100
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultReactiveOAuth2AuthorizedClientManagerTests.java

@ -34,9 +34,9 @@ import org.springframework.security.oauth2.client.web.server.ServerOAuth2Authori
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange; import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.util.context.Context;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
@ -64,6 +64,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
private Authentication principal; private Authentication principal;
private OAuth2AuthorizedClient authorizedClient; private OAuth2AuthorizedClient authorizedClient;
private MockServerWebExchange serverWebExchange; private MockServerWebExchange serverWebExchange;
private Context context;
private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor; private ArgumentCaptor<OAuth2AuthorizationContext> authorizationContextCaptor;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -75,6 +76,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class); this.authorizedClientRepository = mock(ServerOAuth2AuthorizedClientRepository.class);
when(this.authorizedClientRepository.loadAuthorizedClient( when(this.authorizedClientRepository.loadAuthorizedClient(
anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty()); anyString(), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
when(this.authorizedClientRepository.saveAuthorizedClient(
any(OAuth2AuthorizedClient.class), any(Authentication.class), any(ServerWebExchange.class))).thenReturn(Mono.empty());
this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty());
this.contextAttributesMapper = mock(Function.class); this.contextAttributesMapper = mock(Function.class);
@ -88,6 +91,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); this.serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class);
} }
@ -119,16 +123,6 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.hasMessage("contextAttributesMapper cannot be null"); .hasMessage("contextAttributesMapper cannot be null");
} }
@Test
public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("serverWebExchange cannot be null");
}
@Test @Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block()) assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block())
@ -140,9 +134,8 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id")
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block())
.isInstanceOf(IllegalArgumentException.class) .isInstanceOf(IllegalArgumentException.class)
.hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
} }
@ -155,9 +148,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -168,8 +161,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isNull(); assertThat(authorizedClient).isNull();
verify(this.authorizedClientRepository, never()).saveAuthorizedClient( verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -177,15 +169,14 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() {
when(this.clientRegistrationRepository.findByRegistrationId( when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientProvider.authorize( when(this.authorizedClientProvider.authorize(
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -200,6 +191,31 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange)); eq(this.authorizedClient), eq(this.principal), eq(this.serverWebExchange));
} }
@Test
public void authorizeWhenNotAuthorizedAndSupportedProviderAndExchangeUnavailableThenAuthorizedButNotSaved() {
when(this.clientRegistrationRepository.findByRegistrationId(
eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration));
when(this.authorizedClientProvider.authorize(
any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isNull();
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() {
@ -216,9 +232,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(any()); verify(this.contextAttributesMapper).apply(any());
@ -241,21 +257,18 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient));
// Set custom contextAttributesMapper capable of mapping the form parameters // Set custom contextAttributesMapper capable of mapping the form parameters
this.authorizedClientManager.setContextAttributesMapper(authorizeRequest -> { this.authorizedClientManager.setContextAttributesMapper(authorizeRequest ->
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); currentServerWebExchange()
return Mono.just(serverWebExchange)
.flatMap(ServerWebExchange::getFormData) .flatMap(ServerWebExchange::getFormData)
.map(formData -> { .map(formData -> {
Map<String, Object> contextAttributes = new HashMap<>(); Map<String, Object> contextAttributes = new HashMap<>();
String username = formData.getFirst(OAuth2ParameterNames.USERNAME); String username = formData.getFirst(OAuth2ParameterNames.USERNAME);
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
String password = formData.getFirst(OAuth2ParameterNames.PASSWORD); String password = formData.getFirst(OAuth2ParameterNames.PASSWORD);
if (StringUtils.hasText(username) && StringUtils.hasText(password)) { contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
contextAttributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
contextAttributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
}
return contextAttributes; return contextAttributes;
}); })
}); );
this.serverWebExchange = MockServerWebExchange.builder( this.serverWebExchange = MockServerWebExchange.builder(
MockServerHttpRequest MockServerHttpRequest
@ -263,12 +276,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.contentType(MediaType.APPLICATION_FORM_URLENCODED) .contentType(MediaType.APPLICATION_FORM_URLENCODED)
.body("username=username&password=password")) .body("username=username&password=password"))
.build(); .build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
this.authorizedClientManager.authorize(authorizeRequest).block(); this.authorizedClientManager.authorize(authorizeRequest).subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -284,9 +297,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -297,8 +310,7 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal);
assertThat(authorizedClient).isSameAs(this.authorizedClient); assertThat(authorizedClient).isSameAs(this.authorizedClient);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient( verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.serverWebExchange));
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -312,9 +324,9 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest)
.subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -346,12 +358,12 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
.get("/") .get("/")
.queryParam(OAuth2ParameterNames.SCOPE, "read write")) .queryParam(OAuth2ParameterNames.SCOPE, "read write"))
.build(); .build();
this.context = Context.of(ServerWebExchange.class, this.serverWebExchange);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal) .principal(this.principal)
.attribute(ServerWebExchange.class.getName(), this.serverWebExchange)
.build(); .build();
this.authorizedClientManager.authorize(reauthorizeRequest).block(); this.authorizedClientManager.authorize(reauthorizeRequest).subscriberContext(this.context).block();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -359,4 +371,10 @@ public class DefaultReactiveOAuth2AuthorizedClientManagerTests {
String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME);
assertThat(requestScopeAttribute).contains("read", "write"); assertThat(requestScopeAttribute).contains("read", "write");
} }
private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
} }

Loading…
Cancel
Save