diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java index d5032799..79898040 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2023 the original author or authors. + * Copyright 2020-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ import jakarta.servlet.http.HttpServletResponse; import org.springframework.core.log.LogMessage; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; @@ -37,6 +39,7 @@ import org.springframework.security.oauth2.server.authorization.web.authenticati import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; @@ -66,6 +69,8 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil private final RequestMatcher tokenRevocationEndpointMatcher; + private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); + private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendRevocationSuccessResponse; @@ -109,6 +114,11 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil try { Authentication tokenRevocationAuthentication = this.authenticationConverter.convert(request); + if (tokenRevocationAuthentication instanceof AbstractAuthenticationToken) { + ((AbstractAuthenticationToken) tokenRevocationAuthentication) + .setDetails(this.authenticationDetailsSource.buildDetails(request)); + } + Authentication tokenRevocationAuthenticationResult = this.authenticationManager .authenticate(tokenRevocationAuthentication); this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, @@ -123,6 +133,19 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil } } + /** + * Sets the {@link AuthenticationDetailsSource} used for building an authentication + * details instance from {@link HttpServletRequest}. + * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for + * building an authentication details instance from {@link HttpServletRequest} + * @since 1.4 + */ + public void setAuthenticationDetailsSource( + AuthenticationDetailsSource authenticationDetailsSource) { + Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null"); + this.authenticationDetailsSource = authenticationDetailsSource; + } + /** * Sets the {@link AuthenticationConverter} used when attempting to extract a Revoke * Token Request from {@link HttpServletRequest} to an instance of diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java index 6b5bbb6b..f8145320 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2022 the original author or authors. + * Copyright 2020-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,33 +15,23 @@ */ package org.springframework.security.oauth2.server.authorization.web; -import java.time.Duration; -import java.time.Instant; -import java.util.Arrays; -import java.util.HashSet; -import java.util.function.Consumer; - import jakarta.servlet.FilterChain; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; - import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.*; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.OAuth2TokenType; @@ -52,14 +42,19 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetails; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashSet; +import java.util.function.Consumer; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.*; /** * Tests for {@link OAuth2TokenRevocationEndpointFilter}. @@ -102,6 +97,13 @@ public class OAuth2TokenRevocationEndpointFilterTests { .hasMessage("tokenRevocationEndpointUri cannot be empty"); } + @Test + public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationDetailsSource cannot be null"); + } + @Test public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null)) @@ -198,6 +200,40 @@ public class OAuth2TokenRevocationEndpointFilterTests { assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value()); } + @Test + public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient, + ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret()); + + MockHttpServletRequest request = createTokenRevocationRequest(); + + AuthenticationDetailsSource authenticationDetailsSource = mock( + AuthenticationDetailsSource.class); + WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request); + given(authenticationDetailsSource.buildDetails(any())).willReturn(webAuthenticationDetails); + this.filter.setAuthenticationDetailsSource(authenticationDetailsSource); + + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", + Instant.now(), Instant.now().plus(Duration.ofHours(1)), + new HashSet<>(Arrays.asList("scope1", "scope2"))); + OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication = new OAuth2TokenRevocationAuthenticationToken( + accessToken, clientPrincipal); + + given(this.authenticationManager.authenticate(any())).willReturn(tokenRevocationAuthentication); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(clientPrincipal); + SecurityContextHolder.setContext(securityContext); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationDetailsSource).buildDetails(any()); + } + @Test public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception { RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();