Add test for refresh_token grant with public client

Related gh-1432
This commit is contained in:
Joe Grandja
2024-01-10 13:28:11 -05:00
parent e76fde806f
commit faad0be153
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2022 the original author or authors.
* Copyright 2020-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,6 +23,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import jakarta.servlet.http.HttpServletRequest;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
@@ -34,6 +36,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
@@ -43,16 +46,25 @@ import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.lang.Nullable;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.Transient;
import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -66,6 +78,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -77,10 +90,15 @@ import org.springframework.security.oauth2.server.authorization.test.SpringTestC
import org.springframework.security.oauth2.server.authorization.test.SpringTestContextExtension;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.AuthenticationConverter;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.containsString;
@@ -217,6 +235,32 @@ public class OAuth2RefreshTokenGrantTests {
assertThat(accessToken.isActive()).isTrue();
}
// gh-1430
@Test
public void requestWhenRefreshTokenRequestWithPublicClientThenReturnAccessTokenResponse() throws Exception {
this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
.build();
this.registeredClientRepository.save(registeredClient);
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
this.authorizationService.save(authorization);
this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
.params(getRefreshTokenRequestParameters(authorization))
.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()))
.andExpect(status().isOk())
.andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store")))
.andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache")))
.andExpect(jsonPath("$.access_token").isNotEmpty())
.andExpect(jsonPath("$.token_type").isNotEmpty())
.andExpect(jsonPath("$.expires_in").isNotEmpty())
.andExpect(jsonPath("$.refresh_token").isNotEmpty())
.andExpect(jsonPath("$.scope").isNotEmpty());
}
private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
@@ -307,4 +351,119 @@ public class OAuth2RefreshTokenGrantTests {
}
}
@EnableWebSecurity
@Configuration(proxyBeanMethods = false)
static class AuthorizationServerConfigurationWithPublicClientAuthentication extends AuthorizationServerConfiguration {
// @formatter:off
@Bean
SecurityFilterChain authorizationServerSecurityFilterChain(
HttpSecurity http, RegisteredClientRepository registeredClientRepository) throws Exception {
OAuth2AuthorizationServerConfigurer authorizationServerConfigurer =
new OAuth2AuthorizationServerConfigurer();
authorizationServerConfigurer
.clientAuthentication(clientAuthentication ->
clientAuthentication
.authenticationConverter(
new PublicClientRefreshTokenAuthenticationConverter())
.authenticationProvider(
new PublicClientRefreshTokenAuthenticationProvider(registeredClientRepository))
);
RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
http
.securityMatcher(endpointsMatcher)
.authorizeHttpRequests(authorize ->
authorize.anyRequest().authenticated()
)
.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher))
.apply(authorizationServerConfigurer);
return http.build();
}
// @formatter:on
}
@Transient
private static final class PublicClientRefreshTokenAuthenticationToken extends OAuth2ClientAuthenticationToken {
private PublicClientRefreshTokenAuthenticationToken(String clientId) {
super(clientId, ClientAuthenticationMethod.NONE, null, null);
}
private PublicClientRefreshTokenAuthenticationToken(RegisteredClient registeredClient) {
super(registeredClient, ClientAuthenticationMethod.NONE, null);
}
}
private static final class PublicClientRefreshTokenAuthenticationConverter implements AuthenticationConverter {
@Nullable
@Override
public Authentication convert(HttpServletRequest request) {
// grant_type (REQUIRED)
String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
return null;
}
// client_id (REQUIRED)
String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
if (!StringUtils.hasText(clientId)) {
return null;
}
return new PublicClientRefreshTokenAuthenticationToken(clientId);
}
}
private static final class PublicClientRefreshTokenAuthenticationProvider implements AuthenticationProvider {
private final RegisteredClientRepository registeredClientRepository;
private PublicClientRefreshTokenAuthenticationProvider(RegisteredClientRepository registeredClientRepository) {
Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
this.registeredClientRepository = registeredClientRepository;
}
@Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException {
PublicClientRefreshTokenAuthenticationToken publicClientAuthentication =
(PublicClientRefreshTokenAuthenticationToken) authentication;
if (!ClientAuthenticationMethod.NONE.equals(publicClientAuthentication.getClientAuthenticationMethod())) {
return null;
}
String clientId = publicClientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
}
if (!registeredClient.getClientAuthenticationMethods().contains(
publicClientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method");
}
return new PublicClientRefreshTokenAuthenticationToken(registeredClient);
}
@Override
public boolean supports(Class<?> authentication) {
return PublicClientRefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}
private static void throwInvalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(
OAuth2ErrorCodes.INVALID_CLIENT,
"Public client authentication failed: " + parameterName,
null
);
throw new OAuth2AuthenticationException(error);
}
}
}