diff --git a/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java b/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java index 318bcd90f3..38f9f30861 100644 --- a/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java +++ b/config/src/test/java/org/springframework/security/config/SecurityNamespaceHandlerTests.java @@ -17,9 +17,7 @@ package org.springframework.security.config; import org.apache.commons.logging.Log; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; @@ -48,9 +46,6 @@ import static org.mockito.Mockito.verifyZeroInteractions; @PowerMockIgnore({ "org.w3c.dom.*", "org.xml.sax.*", "org.apache.xerces.*", "javax.xml.parsers.*" }) public class SecurityNamespaceHandlerTests { - @Rule - public ExpectedException thrown = ExpectedException.none(); - // @formatter:off private static final String XML_AUTHENTICATION_MANAGER = "" + " " @@ -103,12 +98,12 @@ public class SecurityNamespaceHandlerTests { @Test public void filterNoClassDefFoundError() throws Exception { String className = "javax.servlet.Filter"; - this.thrown.expect(BeanDefinitionParsingException.class); - this.thrown.expectMessage("NoClassDefFoundError: " + className); PowerMockito.spy(ClassUtils.class); PowerMockito.doThrow(new NoClassDefFoundError(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); - new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK)) + .withMessageContaining("NoClassDefFoundError: " + className); } @Test @@ -124,12 +119,12 @@ public class SecurityNamespaceHandlerTests { @Test public void filterChainProxyClassNotFoundException() throws Exception { String className = FILTER_CHAIN_PROXY_CLASSNAME; - this.thrown.expect(BeanDefinitionParsingException.class); - this.thrown.expectMessage("ClassNotFoundException: " + className); PowerMockito.spy(ClassUtils.class); PowerMockito.doThrow(new ClassNotFoundException(className)).when(ClassUtils.class, "forName", eq(FILTER_CHAIN_PROXY_CLASSNAME), any(ClassLoader.class)); - new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK); + assertThatExceptionOfType(BeanDefinitionParsingException.class) + .isThrownBy(() -> new InMemoryXmlApplicationContext(XML_AUTHENTICATION_MANAGER + XML_HTTP_BLOCK)) + .withMessageContaining("ClassNotFoundException: " + className); } @Test diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java index db95afd135..5182cd9508 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/GlobalMethodSecurityConfigurationTests.java @@ -25,7 +25,6 @@ import javax.sql.DataSource; import org.aopalliance.intercept.MethodInterceptor; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.springframework.beans.BeansException; @@ -80,9 +79,6 @@ public class GlobalMethodSecurityConfigurationTests { @Rule public final SpringTestRule spring = new SpringTestRule(); - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Autowired(required = false) private MethodSecurityService service; @@ -98,8 +94,8 @@ public class GlobalMethodSecurityConfigurationTests { @Test public void configureWhenGlobalMethodSecurityIsMissingMetadataSourceThenException() { - this.thrown.expect(UnsatisfiedDependencyException.class); - this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire(); + assertThatExceptionOfType(UnsatisfiedDependencyException.class) + .isThrownBy(() -> this.spring.register(IllegalStateGlobalMethodSecurityConfig.class).autowire()); } @Test diff --git a/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java b/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java index ce02345c92..9c31fcd753 100644 --- a/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java +++ b/crypto/src/test/java/org/springframework/security/crypto/codec/HexTests.java @@ -16,11 +16,10 @@ package org.springframework.security.crypto.codec; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Test cases for {@link Hex}. @@ -29,9 +28,6 @@ import static org.assertj.core.api.Assertions.assertThat; */ public class HexTests { - @Rule - public ExpectedException expectedException = ExpectedException.none(); - @Test public void encode() { assertThat(Hex.encode(new byte[] { (byte) 'A', (byte) 'B', (byte) 'C', (byte) 'D' })) @@ -55,30 +51,26 @@ public class HexTests { @Test public void decodeNotEven() { - this.expectedException.expect(IllegalArgumentException.class); - this.expectedException.expectMessage("Hex-encoded string must have an even number of characters"); - Hex.decode("414243444"); + assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("414243444")) + .withMessage("Hex-encoded string must have an even number of characters"); } @Test public void decodeExistNonHexCharAtFirst() { - this.expectedException.expect(IllegalArgumentException.class); - this.expectedException.expectMessage("Detected a Non-hex character at 1 or 2 position"); - Hex.decode("G0"); + assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("G0")) + .withMessage("Detected a Non-hex character at 1 or 2 position"); } @Test public void decodeExistNonHexCharAtSecond() { - this.expectedException.expect(IllegalArgumentException.class); - this.expectedException.expectMessage("Detected a Non-hex character at 3 or 4 position"); - Hex.decode("410G"); + assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("410G")) + .withMessage("Detected a Non-hex character at 3 or 4 position"); } @Test public void decodeExistNonHexCharAtBoth() { - this.expectedException.expect(IllegalArgumentException.class); - this.expectedException.expectMessage("Detected a Non-hex character at 5 or 6 position"); - Hex.decode("4142GG"); + assertThatIllegalArgumentException().isThrownBy(() -> Hex.decode("4142GG")) + .withMessage("Detected a Non-hex character at 5 or 6 position"); } } diff --git a/itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java b/itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java index ce6677bd68..acb872697d 100644 --- a/itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java +++ b/itest/ldap/embedded-ldap-none/src/integration-test/java/org/springframework/security/LdapServerBeanDefinitionParserTests.java @@ -17,21 +17,18 @@ package org.springframework.security; import org.junit.After; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.beans.factory.BeanDefinitionStoreException; import org.springframework.context.support.ClassPathXmlApplicationContext; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + /** * @author EddĂș MelĂ©ndez */ public class LdapServerBeanDefinitionParserTests { - @Rule - public ExpectedException thrown = ExpectedException.none(); - private ClassPathXmlApplicationContext context; @After @@ -44,10 +41,9 @@ public class LdapServerBeanDefinitionParserTests { @Test public void apacheDirectoryServerIsStartedByDefault() { - this.thrown.expect(BeanDefinitionStoreException.class); - this.thrown.expectMessage("Embedded LDAP server is not provided"); - - this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml"); + assertThatExceptionOfType(BeanDefinitionStoreException.class) + .isThrownBy(() -> this.context = new ClassPathXmlApplicationContext("applicationContext-security.xml")) + .withMessageContaining("Embedded LDAP server is not provided"); } } diff --git a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java index 1c4422344f..72846149c2 100644 --- a/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java +++ b/ldap/src/test/java/org/springframework/security/ldap/authentication/ad/ActiveDirectoryLdapAuthenticationProviderTests.java @@ -30,14 +30,8 @@ import javax.naming.directory.SearchControls; import javax.naming.directory.SearchResult; import org.apache.directory.shared.ldap.util.EmptyEnumeration; -import org.hamcrest.BaseMatcher; -import org.hamcrest.CoreMatchers; -import org.hamcrest.Description; -import org.hamcrest.Matcher; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.springframework.dao.IncorrectResultSizeDataAccessException; @@ -71,9 +65,6 @@ public class ActiveDirectoryLdapAuthenticationProviderTests { public static final String NON_EXISTING_LDAP_PROVIDER = "ldap://192.168.1.201/"; - @Rule - public ExpectedException thrown = ExpectedException.none(); - ActiveDirectoryLdapAuthenticationProvider provider; UsernamePasswordAuthenticationToken joe = new UsernamePasswordAuthenticationToken("joe", "password"); @@ -245,29 +236,10 @@ public class ActiveDirectoryLdapAuthenticationProviderTests { this.provider.contextFactory = createContextFactoryThrowing( new AuthenticationException(msg + dataCode + ", xxxx]")); this.provider.setConvertSubErrorCodesToExceptions(true); - this.thrown.expect(BadCredentialsException.class); - this.thrown.expect(new BaseMatcher() { - private Matcher causeInstance = CoreMatchers - .instanceOf(ActiveDirectoryAuthenticationException.class); - - private Matcher causeDataCode = CoreMatchers.equalTo(dataCode); - - @Override - public boolean matches(Object that) { - Throwable t = (Throwable) that; - ActiveDirectoryAuthenticationException cause = (ActiveDirectoryAuthenticationException) t.getCause(); - return this.causeInstance.matches(cause) && this.causeDataCode.matches(cause.getDataCode()); - } - - @Override - public void describeTo(Description desc) { - desc.appendText("getCause() "); - this.causeInstance.describeTo(desc); - desc.appendText("getCause().getDataCode() "); - this.causeDataCode.describeTo(desc); - } - }); - this.provider.authenticate(this.joe); + assertThatExceptionOfType(BadCredentialsException.class).isThrownBy(() -> this.provider.authenticate(this.joe)) + .withCauseInstanceOf(ActiveDirectoryAuthenticationException.class) + .satisfies((ex) -> assertThat(((ActiveDirectoryAuthenticationException) ex.getCause()).getDataCode()) + .isEqualTo(dataCode)); } @Test(expected = CredentialsExpiredException.class) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java index cbe7c6b05c..fb7a4d5c11 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/authentication/OAuth2LoginAuthenticationProviderTests.java @@ -25,9 +25,7 @@ import java.util.Map; import java.util.Set; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; @@ -52,7 +50,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization import org.springframework.security.oauth2.core.user.OAuth2User; import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.containsString; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.BDDMockito.given; @@ -79,9 +78,6 @@ public class OAuth2LoginAuthenticationProviderTests { private OAuth2LoginAuthenticationProvider authenticationProvider; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before @SuppressWarnings("unchecked") public void setUp() { @@ -98,20 +94,19 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new OAuth2LoginAuthenticationProvider(null, this.userService); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationProvider(null, this.userService)); } @Test public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OAuth2LoginAuthenticationProvider(this.accessTokenResponseClient, null)); } @Test public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.authenticationProvider.setAuthoritiesMapper(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null)); } @Test @@ -132,26 +127,26 @@ public class OAuth2LoginAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_REQUEST)); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() .errorCode(OAuth2ErrorCodes.INVALID_REQUEST).build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining(OAuth2ErrorCodes.INVALID_REQUEST); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_state_parameter")); OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("67890") .build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining("invalid_state_parameter"); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java index 17fc44c706..f3c9001a66 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusAuthorizationCodeTokenResponseClientTests.java @@ -21,9 +21,7 @@ import java.time.Instant; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -40,7 +38,8 @@ import org.springframework.security.oauth2.core.endpoint.TestOAuth2Authorization import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses; import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.containsString; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link NimbusAuthorizationCodeTokenResponseClient}. @@ -59,9 +58,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { private NimbusAuthorizationCodeTokenResponseClient tokenResponseClient = new NimbusAuthorizationCodeTokenResponseClient(); - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setUp() { this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration() @@ -109,29 +105,27 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenRedirectUriMalformedThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); String redirectUri = "http:\\example.com"; OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request() .redirectUri(redirectUri).build(); OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, this.authorizationResponse); - this.tokenResponseClient.getTokenResponse( - new OAuth2AuthorizationCodeGrantRequest(this.clientRegistrationBuilder.build(), authorizationExchange)); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), authorizationExchange))); } @Test public void getTokenResponseWhenTokenUriMalformedThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); String tokenUri = "http:\\provider.com\\oauth2\\token"; this.clientRegistrationBuilder.tokenUri(tokenUri); - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))); } @Test public void getTokenResponseWhenSuccessResponseInvalidThenThrowOAuth2AuthorizationException() throws Exception { - this.exception.expect(OAuth2AuthorizationException.class); - this.exception.expectMessage(containsString("invalid_token_response")); MockWebServer server = new MockWebServer(); // @formatter:off String accessTokenSuccessResponse = "{\n" @@ -149,8 +143,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); try { - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))) + .withMessageContaining("invalid_token_response"); } finally { server.shutdown(); @@ -159,17 +155,15 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() { - this.exception.expect(OAuth2AuthorizationException.class); String tokenUri = "https://invalid-provider.com/oauth2/token"; this.clientRegistrationBuilder.tokenUri(tokenUri); - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))); } @Test public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { - this.exception.expect(OAuth2AuthorizationException.class); - this.exception.expectMessage(containsString("unauthorized_client")); MockWebServer server = new MockWebServer(); // @formatter:off String accessTokenErrorResponse = "{\n" @@ -182,8 +176,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); try { - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))) + .withMessageContaining("unauthorized_client"); } finally { server.shutdown(); @@ -193,16 +189,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { // gh-5594 @Test public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() throws Exception { - this.exception.expect(OAuth2AuthorizationException.class); - this.exception.expectMessage(containsString("server_error")); MockWebServer server = new MockWebServer(); server.enqueue(new MockResponse().setResponseCode(500)); server.start(); String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); try { - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))) + .withMessageContaining("server_error"); } finally { server.shutdown(); @@ -212,8 +208,6 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { @Test public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() throws Exception { - this.exception.expect(OAuth2AuthorizationException.class); - this.exception.expectMessage(containsString("invalid_token_response")); MockWebServer server = new MockWebServer(); // @formatter:off String accessTokenSuccessResponse = "{\n" @@ -228,8 +222,10 @@ public class NimbusAuthorizationCodeTokenResponseClientTests { String tokenUri = server.url("/oauth2/token").toString(); this.clientRegistrationBuilder.tokenUri(tokenUri); try { - this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( - this.clientRegistrationBuilder.build(), this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthorizationException.class) + .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(new OAuth2AuthorizationCodeGrantRequest( + this.clientRegistrationBuilder.build(), this.authorizationExchange))) + .withMessageContaining("invalid_token_response"); } finally { server.shutdown(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java index 2b0e03fa6a..1063f9d5ac 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcAuthorizationCodeAuthenticationProviderTests.java @@ -28,9 +28,7 @@ import java.util.Map; import java.util.Set; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.mockito.ArgumentCaptor; import org.mockito.stubbing.Answer; @@ -64,7 +62,8 @@ import org.springframework.security.oauth2.jwt.JwtException; import org.springframework.security.oauth2.jwt.TestJwts; import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.containsString; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyCollection; import static org.mockito.ArgumentMatchers.anyString; @@ -100,9 +99,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { private String nonceHash; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before @SuppressWarnings("unchecked") public void setUp() { @@ -138,26 +134,24 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Test public void constructorWhenAccessTokenResponseClientIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new OidcAuthorizationCodeAuthenticationProvider(null, this.userService); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcAuthorizationCodeAuthenticationProvider(null, this.userService)); } @Test public void constructorWhenUserServiceIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null); + assertThatIllegalArgumentException().isThrownBy( + () -> new OidcAuthorizationCodeAuthenticationProvider(this.accessTokenResponseClient, null)); } @Test public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.authenticationProvider.setJwtDecoderFactory(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setJwtDecoderFactory(null)); } @Test public void setAuthoritiesMapperWhenAuthoritiesMapperIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.authenticationProvider.setAuthoritiesMapper(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.authenticationProvider.setAuthoritiesMapper(null)); } @Test @@ -181,8 +175,6 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { @Test public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE)); // @formatter:off OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() .errorCode(OAuth2ErrorCodes.INVALID_SCOPE) @@ -190,14 +182,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { // @formatter:on OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining(OAuth2ErrorCodes.INVALID_SCOPE); } @Test public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_state_parameter")); // @formatter:off OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success() .state("89012") @@ -205,14 +197,14 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { // @formatter:on OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, authorizationResponse); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, authorizationExchange))) + .withMessageContaining("invalid_state_parameter"); } @Test public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_id_token")); // @formatter:off OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withResponse(this.accessTokenSuccessResponse()) @@ -220,38 +212,38 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { .build(); // @formatter:on given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange))) + .withMessageContaining("invalid_id_token"); } @Test public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("missing_signature_verifier")); // @formatter:off ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() .jwkSetUri(null) .build(); // @formatter:on - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange))) + .withMessageContaining("missing_signature_verifier"); } @Test public void authenticateWhenIdTokenValidationErrorThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_id_token] ID Token Validation Error")); JwtDecoder jwtDecoder = mock(JwtDecoder.class); given(jwtDecoder.decode(anyString())).willThrow(new JwtException("ID Token Validation Error")); this.authenticationProvider.setJwtDecoderFactory((registration) -> jwtDecoder); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange))) + .withMessageContaining("[invalid_id_token] ID Token Validation Error"); } @Test public void authenticateWhenIdTokenInvalidNonceThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("[invalid_nonce]")); Map claims = new HashMap<>(); claims.put(IdTokenClaimNames.ISS, "https://provider.com"); claims.put(IdTokenClaimNames.SUB, "subject1"); @@ -259,8 +251,10 @@ public class OidcAuthorizationCodeAuthenticationProviderTests { claims.put(IdTokenClaimNames.AZP, "client1"); claims.put(IdTokenClaimNames.NONCE, "invalid-nonce-hash"); this.setUpIdToken(claims); - this.authenticationProvider - .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.authenticationProvider.authenticate( + new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange))) + .withMessageContaining("[invalid_nonce]"); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 3693d635be..bc1a4fd6f2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -29,9 +29,7 @@ import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; @@ -56,8 +54,8 @@ import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.same; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -80,9 +78,6 @@ public class OidcUserServiceTests { private MockWebServer server; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setup() throws Exception { this.server = new MockWebServer(); @@ -133,8 +128,7 @@ public class OidcUserServiceTests { @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.loadUser(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); } @Test @@ -260,8 +254,6 @@ public class OidcUserServiceTests { // gh-5447 @Test public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); // @formatter:off String userInfoResponse = "{\n" + " \"email\": \"full_name@provider.com\",\n" @@ -272,25 +264,26 @@ public class OidcUserServiceTests { String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userNameAttributeName(StandardClaimNames.EMAIL).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))) + .withMessageContaining("invalid_user_info_response"); } @Test public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectNotSameAsIdTokenSubjectThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("invalid_user_info_response")); String userInfoResponse = "{\n" + " \"sub\": \"other-subject\"\n" + "}\n"; this.server.enqueue(jsonResponse(userInfoResponse)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))) + .withMessageContaining("invalid_user_info_response"); } @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); // @formatter:off String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" @@ -304,28 +297,34 @@ public class OidcUserServiceTests { this.server.enqueue(jsonResponse(userInfoResponse)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy(() -> this.userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java index c1279db529..9ed6d0b119 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/CustomUserTypesOAuth2UserServiceTests.java @@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; @@ -43,7 +41,8 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.user.OAuth2User; import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.containsString; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link CustomUserTypesOAuth2UserService}. @@ -61,9 +60,6 @@ public class CustomUserTypesOAuth2UserServiceTests { private MockWebServer server; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setUp() throws Exception { this.server = new MockWebServer(); @@ -86,32 +82,28 @@ public class CustomUserTypesOAuth2UserServiceTests { @Test public void constructorWhenCustomUserTypesIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new CustomUserTypesOAuth2UserService(null); + assertThatIllegalArgumentException().isThrownBy(() -> new CustomUserTypesOAuth2UserService(null)); } @Test public void constructorWhenCustomUserTypesIsEmptyThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - new CustomUserTypesOAuth2UserService(Collections.emptyMap()); + assertThatIllegalArgumentException() + .isThrownBy(() -> new CustomUserTypesOAuth2UserService(Collections.emptyMap())); } @Test public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.setRequestEntityConverter(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null)); } @Test public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.setRestOperations(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null)); } @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.loadUser(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); } @Test @@ -151,9 +143,6 @@ public class CustomUserTypesOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); // @formatter:off String userInfoResponse = "{\n" + " \"id\": \"12345\",\n" @@ -166,28 +155,34 @@ public class CustomUserTypesOAuth2UserServiceTests { this.server.enqueue(jsonResponse(userInfoResponse)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } private ClientRegistration.Builder withRegistrationId(String registrationId) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index cb6da9fa93..a2df2f4e44 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -26,9 +26,7 @@ import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; @@ -51,7 +49,8 @@ import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import org.springframework.web.client.RestOperations; import static org.assertj.core.api.Assertions.assertThat; -import static org.hamcrest.CoreMatchers.containsString; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.BDDMockito.given; @@ -73,9 +72,6 @@ public class DefaultOAuth2UserServiceTests { private MockWebServer server; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setup() throws Exception { this.server = new MockWebServer(); @@ -95,40 +91,39 @@ public class DefaultOAuth2UserServiceTests { @Test public void setRequestEntityConverterWhenNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.setRequestEntityConverter(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRequestEntityConverter(null)); } @Test public void setRestOperationsWhenNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.setRestOperations(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.setRestOperations(null)); } @Test public void loadUserWhenUserRequestIsNullThenThrowIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.userService.loadUser(null); + assertThatIllegalArgumentException().isThrownBy(() -> this.userService.loadUser(null)); } @Test public void loadUserWhenUserInfoUriIsNullThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("missing_user_info_uri")); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining("missing_user_info_uri"); } @Test public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString("missing_user_name_attribute")); // @formatter:off ClientRegistration clientRegistration = this.clientRegistrationBuilder .userInfoUri("https://provider.com/user") .build(); // @formatter:on - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining("missing_user_name_attribute"); } @Test @@ -165,9 +160,6 @@ public class DefaultOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); // @formatter:off String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" @@ -182,16 +174,15 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } @Test public void loadUserWhenUserInfoErrorResponseWwwAuthenticateHeaderThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - this.exception.expectMessage( - containsString("Error Code: insufficient_scope, Error Description: The access token expired")); String wwwAuthenticateHeader = "Bearer realm=\"auth-realm\" error=\"insufficient_scope\" error_description=\"The access token expired\""; MockResponse response = new MockResponse(); response.setHeader(HttpHeaders.WWW_AUTHENTICATE, wwwAuthenticateHeader); @@ -200,15 +191,16 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource") + .withMessageContaining("Error Code: insufficient_scope, Error Description: The access token expired"); } @Test public void loadUserWhenUserInfoErrorResponseThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); - this.exception.expectMessage(containsString("Error Code: invalid_token")); // @formatter:off String userInfoErrorResponse = "{\n" + " \"error\": \"invalid_token\"\n" @@ -218,30 +210,37 @@ public class DefaultOAuth2UserServiceTests { String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource") + .withMessageContaining("Error Code: invalid_token"); } @Test public void loadUserWhenServerErrorThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error")); this.server.enqueue(new MockResponse().setResponseCode(500)); String userInfoUri = this.server.url("/user").toString(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource: 500 Server Error"); } @Test public void loadUserWhenUserInfoUriInvalidThenThrowOAuth2AuthenticationException() { - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); String userInfoUri = "https://invalid-provider.com/user"; ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"); } // gh-5294 @@ -348,17 +347,18 @@ public class DefaultOAuth2UserServiceTests { @Test public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2AuthenticationException() { String userInfoUri = this.server.url("/user").toString(); - this.exception.expect(OAuth2AuthenticationException.class); - this.exception.expectMessage(containsString( - "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource " - + "from '" + userInfoUri + "': response contains invalid content type 'text/plain'.")); MockResponse response = new MockResponse(); response.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE); response.setBody("invalid content type"); this.server.enqueue(response); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER).userNameAttributeName("user-name").build(); - this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThatExceptionOfType(OAuth2AuthenticationException.class) + .isThrownBy( + () -> this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken))) + .withMessageContaining( + "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource " + + "from '" + userInfoUri + "': response contains invalid content type 'text/plain'."); } private DefaultOAuth2UserService withMockResponse(Map response) { diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java index 5f41ee0e08..3ea40c5d80 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2X509CredentialTests.java @@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.security.converter.RsaKeyConverters; import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType; -public class Saml2X509CredentialTests { +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; - @Rule - public ExpectedException exception = ExpectedException.none(); +public class Saml2X509CredentialTests { private PrivateKey key; @@ -99,98 +97,90 @@ public class Saml2X509CredentialTests { @Test public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException().isThrownBy( + () -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); + assertThatIllegalStateException().isThrownBy( + () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION)); } @Test public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); + assertThatIllegalStateException().isThrownBy( + () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION)); } @Test public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); + assertThatIllegalStateException() + .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); + assertThatIllegalStateException() + .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION)); } @Test public void factoryWhenRelyingPartyForSigningWithoutCredentialsThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.signing(null, null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, null)); } @Test public void factoryWhenRelyingPartyForSigningWithoutPrivateKeyThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.signing(null, this.certificate); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(null, this.certificate)); } @Test public void factoryWhenRelyingPartyForSigningWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.signing(this.key, null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.signing(this.key, null)); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutCredentialsThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.decryption(null, null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, null)); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutPrivateKeyThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.decryption(null, this.certificate); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(null, this.certificate)); } @Test public void factoryWhenRelyingPartyForDecryptionWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.decryption(this.key, null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.decryption(this.key, null)); } @Test public void factoryWhenAssertingPartyForVerificationWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.verification(null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.verification(null)); } @Test public void factoryWhenAssertingPartyForEncryptionWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - Saml2X509Credential.encryption(null); + assertThatIllegalArgumentException().isThrownBy(() -> Saml2X509Credential.encryption(null)); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java index dd9d9ba715..5c3ac1231c 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/credentials/Saml2X509CredentialTests.java @@ -23,17 +23,15 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.security.converter.RsaKeyConverters; import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType; -public class Saml2X509CredentialTests { +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; - @Rule - public ExpectedException exception = ExpectedException.none(); +public class Saml2X509CredentialTests { private Saml2X509Credential credential; @@ -97,50 +95,50 @@ public class Saml2X509CredentialTests { @Test public void constructorWhenRelyingPartyWithoutCredentialsThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException().isThrownBy( + () -> new Saml2X509Credential(null, (X509Certificate) null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithoutPrivateKeyThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(null, this.certificate, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(this.key, null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenAssertingPartyWithoutCertificateThenItFails() { - this.exception.expect(IllegalArgumentException.class); - new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING); + assertThatIllegalArgumentException() + .isThrownBy(() -> new Saml2X509Credential(null, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenRelyingPartyWithEncryptionUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION); + assertThatIllegalStateException().isThrownBy( + () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.ENCRYPTION)); } @Test public void constructorWhenRelyingPartyWithVerificationUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION); + assertThatIllegalStateException().isThrownBy( + () -> new Saml2X509Credential(this.key, this.certificate, Saml2X509CredentialType.VERIFICATION)); } @Test public void constructorWhenAssertingPartyWithSigningUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING); + assertThatIllegalStateException() + .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.SIGNING)); } @Test public void constructorWhenAssertingPartyWithDecryptionUsageThenItFails() { - this.exception.expect(IllegalStateException.class); - new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION); + assertThatIllegalStateException() + .isThrownBy(() -> new Saml2X509Credential(this.certificate, Saml2X509CredentialType.DECRYPTION)); } } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java index a1bac457fd..e0ebdb6e25 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java @@ -26,18 +26,14 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import javax.xml.namespace.QName; import net.shibboleth.utilities.java.support.xml.SerializeSupport; -import org.hamcrest.BaseMatcher; -import org.hamcrest.Description; -import org.hamcrest.Matcher; import org.joda.time.DateTime; import org.joda.time.Duration; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.core.xml.io.Marshaller; @@ -93,9 +89,6 @@ public class OpenSamlAuthenticationProviderTests { private Saml2Authentication authentication = new Saml2Authentication(this.principal, "response", Collections.emptyList()); - @Rule - public ExpectedException exception = ExpectedException.none(); - @Test public void supportsWhenSaml2AuthenticationTokenThenReturnTrue() { assertThat(this.provider.supports(Saml2AuthenticationToken.class)) @@ -113,53 +106,56 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); Assertion assertion = (Assertion) XMLObjectProviderRegistrySupport.getBuilderFactory() .getBuilder(Assertion.DEFAULT_ELEMENT_NAME).buildObject(Assertion.DEFAULT_ELEMENT_NAME); - this.provider - .authenticate(token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential())); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate( + token(serialize(assertion), TestSaml2X509Credentials.relyingPartyVerifyingCredential()))) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); } @Test public void authenticateWhenXmlErrorThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); Saml2AuthenticationToken token = token("invalid xml", TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA)); } @Test public void authenticateWhenInvalidDestinationThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_DESTINATION)); Response response = TestOpenSamlObjects.response(DESTINATION + "invalid", ASSERTING_PARTY_ENTITY_ID); response.getAssertions().add(TestOpenSamlObjects.assertion()); TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_DESTINATION)); } @Test public void authenticateWhenNoAssertionsPresentThenThrowAuthenticationException() { - this.exception.expect( - authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); Saml2AuthenticationToken token = token(TestOpenSamlObjects.response(), TestSaml2X509Credentials.assertingPartySigningCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA, "No assertions found in response.")); } @Test public void authenticateWhenInvalidSignatureOnAssertionThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE)); Response response = TestOpenSamlObjects.response(); response.getAssertions().add(TestOpenSamlObjects.assertion()); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); } @Test public void authenticateWhenOpenSAMLValidationErrorThenThrowAuthenticationException() throws Exception { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_ASSERTION)); Response response = TestOpenSamlObjects.response(); Assertion assertion = TestOpenSamlObjects.assertion(); assertion.getSubject().getSubjectConfirmations().get(0).getSubjectConfirmationData() @@ -168,12 +164,13 @@ public class OpenSamlAuthenticationProviderTests { RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_ASSERTION)); } @Test public void authenticateWhenMissingSubjectThenThrowAuthenticationException() { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); Response response = TestOpenSamlObjects.response(); Assertion assertion = TestOpenSamlObjects.assertion(); assertion.setSubject(null); @@ -181,12 +178,13 @@ public class OpenSamlAuthenticationProviderTests { RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); } @Test public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); Response response = TestOpenSamlObjects.response(); Assertion assertion = TestOpenSamlObjects.assertion(); assertion.getSubject().getNameID().setValue(null); @@ -194,7 +192,9 @@ public class OpenSamlAuthenticationProviderTests { RELYING_PARTY_ENTITY_ID); response.getAssertions().add(assertion); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.SUBJECT_NOT_FOUND)); } @Test @@ -236,13 +236,14 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception { - this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE)); Response response = TestOpenSamlObjects.response(); EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE)); } @Test @@ -290,28 +291,28 @@ public class OpenSamlAuthenticationProviderTests { @Test public void authenticateWhenDecryptionKeysAreMissingThenThrowAuthenticationException() throws Exception { - this.exception - .expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); Response response = TestOpenSamlObjects.response(); EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); Saml2AuthenticationToken token = token(serialize(response), TestSaml2X509Credentials.relyingPartyVerifyingCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); } @Test public void authenticateWhenDecryptionKeysAreWrongThenThrowAuthenticationException() throws Exception { - this.exception - .expect(authenticationMatcher(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); Response response = TestOpenSamlObjects.response(); EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(), TestSaml2X509Credentials.assertingPartyEncryptingCredential()); response.getEncryptedAssertions().add(encryptedAssertion); Saml2AuthenticationToken token = token(serialize(response), TestSaml2X509Credentials.assertingPartyPrivateCredential()); - this.provider.authenticate(token); + assertThatExceptionOfType(Saml2AuthenticationException.class) + .isThrownBy(() -> this.provider.authenticate(token)) + .satisfies(errorOf(Saml2ErrorCodes.DECRYPTION_ERROR, "Failed to decrypt EncryptedData")); } @Test @@ -487,33 +488,15 @@ public class OpenSamlAuthenticationProviderTests { } } - private Matcher authenticationMatcher(String code) { - return authenticationMatcher(code, null); - } - - private Matcher authenticationMatcher(String code, String description) { - return new BaseMatcher() { - @Override - public boolean matches(Object item) { - if (!(item instanceof Saml2AuthenticationException)) { - return false; - } - Saml2AuthenticationException ex = (Saml2AuthenticationException) item; - if (!code.equals(ex.getError().getErrorCode())) { - return false; - } - if (StringUtils.hasText(description)) { - if (!description.equals(ex.getError().getDescription())) { - return false; - } - } - return true; - } + private Consumer errorOf(String errorCode) { + return errorOf(errorCode, null); + } - @Override - public void describeTo(Description desc) { - String excepting = "Saml2AuthenticationException[code=" + code + "; description=" + description + "]"; - desc.appendText(excepting); + private Consumer errorOf(String errorCode, String description) { + return (ex) -> { + assertThat(ex.getError().getErrorCode()).isEqualTo(errorCode); + if (StringUtils.hasText(description)) { + assertThat(ex.getError().getDescription()).isEqualTo(description); } }; } diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java index 99fc66b954..8ddb48ef79 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java @@ -21,9 +21,7 @@ import java.nio.charset.StandardCharsets; import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.saml2.core.AuthnRequest; @@ -39,7 +37,6 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -61,9 +58,6 @@ public class OpenSamlAuthenticationRequestFactoryTests { private AuthnRequestUnmarshaller unmarshaller; - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setUp() { this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id") @@ -160,9 +154,8 @@ public class OpenSamlAuthenticationRequestFactoryTests { @Test public void createAuthenticationRequestWhenSetUnsupportredUriThenThrowsIllegalArgumentException() { - this.exception.expect(IllegalArgumentException.class); - this.exception.expectMessage(containsString("my-invalid-binding")); - this.factory.setProtocolBinding("my-invalid-binding"); + assertThatIllegalArgumentException().isThrownBy(() -> this.factory.setProtocolBinding("my-invalid-binding")) + .withMessageContaining("my-invalid-binding"); } @Test diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java index 826b04fa22..1563951b15 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java @@ -20,9 +20,7 @@ import javax.servlet.http.HttpServletResponse; import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -43,9 +41,6 @@ public class Saml2WebSsoAuthenticationFilterTests { private HttpServletResponse response = new MockHttpServletResponse(); - @Rule - public ExpectedException exception = ExpectedException.none(); - @Before public void setup() { this.filter = new Saml2WebSsoAuthenticationFilter(this.repository); @@ -55,9 +50,9 @@ public class Saml2WebSsoAuthenticationFilterTests { @Test public void constructingFilterWithMissingRegistrationIdVariableThenThrowsException() { - this.exception.expect(IllegalArgumentException.class); - this.exception.expectMessage("filterProcessesUrl must contain a {registrationId} match variable"); - this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable"); + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy( + () -> this.filter = new Saml2WebSsoAuthenticationFilter(this.repository, "/url/missing/variable")) + .withMessage("filterProcessesUrl must contain a {registrationId} match variable"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java index 196028a094..5117073a42 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/DelegatingAuthenticationFailureHandlerTests.java @@ -22,9 +22,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; @@ -35,6 +33,7 @@ import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.authentication.CredentialsExpiredException; import org.springframework.security.core.AuthenticationException; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; @@ -48,9 +47,6 @@ import static org.mockito.Mockito.verifyZeroInteractions; @RunWith(MockitoJUnitRunner.class) public class DelegatingAuthenticationFailureHandlerTests { - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Mock private AuthenticationFailureHandler handler1; @@ -110,24 +106,24 @@ public class DelegatingAuthenticationFailureHandlerTests { @Test public void handlersIsNull() { - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("handlers cannot be null or empty"); - new DelegatingAuthenticationFailureHandler(null, this.defaultHandler); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingAuthenticationFailureHandler(null, this.defaultHandler)) + .withMessage("handlers cannot be null or empty"); } @Test public void handlersIsEmpty() { - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("handlers cannot be null or empty"); - new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, this.defaultHandler)) + .withMessage("handlers cannot be null or empty"); } @Test public void defaultHandlerIsNull() { - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("defaultHandler cannot be null"); this.handlers.put(BadCredentialsException.class, this.handler1); - new DelegatingAuthenticationFailureHandler(this.handlers, null); + assertThatIllegalArgumentException() + .isThrownBy(() -> new DelegatingAuthenticationFailureHandler(this.handlers, null)) + .withMessage("defaultHandler cannot be null"); } } diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java index 0f6298b12c..40fc27fb11 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/CompositeLogoutHandlerTests.java @@ -22,9 +22,7 @@ import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.mockito.InOrder; import org.springframework.security.core.Authentication; @@ -45,14 +43,10 @@ import static org.mockito.Mockito.verify; */ public class CompositeLogoutHandlerTests { - @Rule - public ExpectedException exception = ExpectedException.none(); - @Test public void buildEmptyCompositeLogoutHandlerThrowsException() { - this.exception.expect(IllegalArgumentException.class); - this.exception.expectMessage("LogoutHandlers are required"); - new CompositeLogoutHandler(); + assertThatIllegalArgumentException().isThrownBy(() -> new CompositeLogoutHandler()) + .withMessage("LogoutHandlers are required"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java index 4dfc0d8c73..980ecd8f8a 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/ForwardLogoutSuccessHandlerTests.java @@ -16,15 +16,14 @@ package org.springframework.security.web.authentication.logout; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.core.Authentication; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; /** @@ -34,23 +33,18 @@ import static org.mockito.Mockito.mock; */ public class ForwardLogoutSuccessHandlerTests { - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Test public void invalidTargetUrl() { String targetUrl = "not.valid"; - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL"); - new ForwardLogoutSuccessHandler(targetUrl); + assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl)) + .withMessage("'" + targetUrl + "' is not a valid target URL"); } @Test public void emptyTargetUrl() { String targetUrl = " "; - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("'" + targetUrl + "' is not a valid target URL"); - new ForwardLogoutSuccessHandler(targetUrl); + assertThatIllegalArgumentException().isThrownBy(() -> new ForwardLogoutSuccessHandler(targetUrl)) + .withMessage("'" + targetUrl + "' is not a valid target URL"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java b/web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java index da1211f284..dd89839d44 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/logout/HeaderWriterLogoutHandlerTests.java @@ -17,15 +17,14 @@ package org.springframework.security.web.authentication.logout; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.core.Authentication; import org.springframework.security.web.header.HeaderWriter; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -40,9 +39,6 @@ public class HeaderWriterLogoutHandlerTests { private MockHttpServletRequest request; - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void setup() { this.response = new MockHttpServletResponse(); @@ -51,9 +47,8 @@ public class HeaderWriterLogoutHandlerTests { @Test public void constructorWhenHeaderWriterIsNullThenThrowsException() { - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("headerWriter cannot be null"); - new HeaderWriterLogoutHandler(null); + assertThatIllegalArgumentException().isThrownBy(() -> new HeaderWriterLogoutHandler(null)) + .withMessage("headerWriter cannot be null"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java index 55a43482a9..18ee35abdf 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/switchuser/SwitchUserFilterTests.java @@ -23,9 +23,7 @@ import javax.servlet.FilterChain; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; @@ -49,6 +47,7 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationSu import org.springframework.security.web.util.matcher.AnyRequestMatcher; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; @@ -64,9 +63,6 @@ public class SwitchUserFilterTests { private static final List ROLES_12 = AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"); - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void authenticateCurrentUser() { UsernamePasswordAuthenticationToken auth = new UsernamePasswordAuthenticationToken("dano", "hawaii50"); @@ -437,9 +433,8 @@ public class SwitchUserFilterTests { // gh-3697 @Test public void switchAuthorityRoleCannotBeNull() { - this.thrown.expect(IllegalArgumentException.class); - this.thrown.expectMessage("switchAuthorityRole cannot be null"); - switchToUserWithAuthorityRole("dano", null); + assertThatIllegalArgumentException().isThrownBy(() -> switchToUserWithAuthorityRole("dano", null)) + .withMessage("switchAuthorityRole cannot be null"); } // gh-3697 diff --git a/web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java b/web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java index 5ded76a468..f7b4fc2a17 100644 --- a/web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java +++ b/web/src/test/java/org/springframework/security/web/firewall/FirewalledResponseTests.java @@ -20,9 +20,7 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletResponse; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.Mockito.mock; @@ -35,8 +33,7 @@ import static org.mockito.Mockito.verify; */ public class FirewalledResponseTests { - @Rule - public ExpectedException expectedException = ExpectedException.none(); + private static final String CRLF_MESSAGE = "Invalid characters (CR/LF)"; private HttpServletResponse response; @@ -62,8 +59,8 @@ public class FirewalledResponseTests { @Test public void sendRedirectWhenHasCrlfThenThrowsException() throws Exception { - expectCrlfValidationException(); - this.fwResponse.sendRedirect("/theURL\r\nsomething"); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.sendRedirect("/theURL\r\nsomething")) + .withMessageContaining(CRLF_MESSAGE); } @Test @@ -80,14 +77,16 @@ public class FirewalledResponseTests { @Test public void addHeaderWhenHeaderValueHasCrlfThenException() { - expectCrlfValidationException(); - this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.fwResponse.addHeader("foo", "abc\r\nContent-Length:100")) + .withMessageContaining(CRLF_MESSAGE); } @Test public void addHeaderWhenHeaderNameHasCrlfThenException() { - expectCrlfValidationException(); - this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.fwResponse.addHeader("abc\r\nContent-Length:100", "bar")) + .withMessageContaining(CRLF_MESSAGE); } @Test @@ -115,39 +114,39 @@ public class FirewalledResponseTests { return "foo\r\nbar"; } }; - expectCrlfValidationException(); - this.fwResponse.addCookie(cookie); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie)) + .withMessageContaining(CRLF_MESSAGE); } @Test public void addCookieWhenCookieValueContainsCrlfThenException() { Cookie cookie = new Cookie("foo", "foo\r\nbar"); - expectCrlfValidationException(); - this.fwResponse.addCookie(cookie); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie)) + .withMessageContaining(CRLF_MESSAGE); } @Test public void addCookieWhenCookiePathContainsCrlfThenException() { Cookie cookie = new Cookie("foo", "bar"); cookie.setPath("/foo\r\nbar"); - expectCrlfValidationException(); - this.fwResponse.addCookie(cookie); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie)) + .withMessageContaining(CRLF_MESSAGE); } @Test public void addCookieWhenCookieDomainContainsCrlfThenException() { Cookie cookie = new Cookie("foo", "bar"); cookie.setDomain("foo\r\nbar"); - expectCrlfValidationException(); - this.fwResponse.addCookie(cookie); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie)) + .withMessageContaining(CRLF_MESSAGE); } @Test public void addCookieWhenCookieCommentContainsCrlfThenException() { Cookie cookie = new Cookie("foo", "bar"); cookie.setComment("foo\r\nbar"); - expectCrlfValidationException(); - this.fwResponse.addCookie(cookie); + assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.addCookie(cookie)) + .withMessageContaining(CRLF_MESSAGE); } @Test @@ -160,11 +159,6 @@ public class FirewalledResponseTests { validateLineEnding("foo\nbar", "bar"); } - private void expectCrlfValidationException() { - this.expectedException.expect(IllegalArgumentException.class); - this.expectedException.expectMessage("Invalid characters (CR/LF)"); - } - private void validateLineEnding(String name, String value) { assertThatIllegalArgumentException().isThrownBy(() -> this.fwResponse.validateCrlf(name, value)); } diff --git a/web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java b/web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java index a559c13282..d55b3326e5 100644 --- a/web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java +++ b/web/src/test/java/org/springframework/security/web/header/writers/ClearSiteDataHeaderWriterTests.java @@ -17,15 +17,14 @@ package org.springframework.security.web.header.writers; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.web.header.writers.ClearSiteDataHeaderWriter.Directive; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; /** * @author Rafiullah Hamedy @@ -40,9 +39,6 @@ public class ClearSiteDataHeaderWriterTests { private MockHttpServletResponse response; - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void setup() { this.request = new MockHttpServletRequest(); @@ -52,9 +48,8 @@ public class ClearSiteDataHeaderWriterTests { @Test public void createInstanceWhenMissingSourceThenThrowsException() { - this.thrown.expect(Exception.class); - this.thrown.expectMessage("directives cannot be empty or null"); - new ClearSiteDataHeaderWriter(); + assertThatExceptionOfType(Exception.class).isThrownBy(() -> new ClearSiteDataHeaderWriter()) + .withMessage("directives cannot be empty or null"); } @Test diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java index 2161e8660c..c15a8de07b 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/SwitchUserWebFilterTests.java @@ -20,9 +20,7 @@ import java.security.Principal; import java.util.Collections; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -55,7 +53,8 @@ import org.springframework.test.util.ReflectionTestUtils; import org.springframework.web.server.WebFilterChain; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.fail; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; @@ -83,9 +82,6 @@ public class SwitchUserWebFilterTests { @Mock private ServerSecurityContextRepository serverSecurityContextRepository; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - @Before public void setUp() { this.switchUserWebFilter = new SwitchUserWebFilter(this.userDetailsService, this.successHandler, @@ -183,11 +179,12 @@ public class SwitchUserWebFilterTests { .from(MockServerHttpRequest.post("/login/impersonate")); final WebFilterChain chain = mock(WebFilterChain.class); final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class)); - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("The userName can not be null."); - this.switchUserWebFilter.filter(exchange, chain) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) - .block(); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain) + .subscriberContext( + ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) + .block()) + .withMessage("The userName can not be null."); verifyNoInteractions(chain); } @@ -219,10 +216,12 @@ public class SwitchUserWebFilterTests { final SecurityContextImpl securityContext = new SecurityContextImpl(mock(Authentication.class)); final UserDetails switchUserDetails = switchUserDetails(targetUsername, false); given(this.userDetailsService.findByUsername(any(String.class))).willReturn(Mono.just(switchUserDetails)); - this.exceptionRule.expect(DisabledException.class); - this.switchUserWebFilter.filter(exchange, chain) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) - .block(); + assertThatExceptionOfType(DisabledException.class) + .isThrownBy( + () -> this.switchUserWebFilter.filter(exchange, chain) + .subscriberContext( + ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) + .block()); verifyNoInteractions(chain); } @@ -265,11 +264,12 @@ public class SwitchUserWebFilterTests { "origCredentials"); final WebFilterChain chain = mock(WebFilterChain.class); final SecurityContextImpl securityContext = new SecurityContextImpl(originalAuthentication); - this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class); - this.exceptionRule.expectMessage("Could not find original Authentication object"); - this.switchUserWebFilter.filter(exchange, chain) - .subscriberContext(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) - .block(); + assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class) + .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain) + .subscriberContext( + ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) + .block()) + .withMessage("Could not find original Authentication object"); verifyNoInteractions(chain); } @@ -278,34 +278,35 @@ public class SwitchUserWebFilterTests { final MockServerWebExchange exchange = MockServerWebExchange .from(MockServerHttpRequest.post("/logout/impersonate")); final WebFilterChain chain = mock(WebFilterChain.class); - this.exceptionRule.expect(AuthenticationCredentialsNotFoundException.class); - this.exceptionRule.expectMessage("No current user associated with this request"); - this.switchUserWebFilter.filter(exchange, chain).block(); + assertThatExceptionOfType(AuthenticationCredentialsNotFoundException.class) + .isThrownBy(() -> this.switchUserWebFilter.filter(exchange, chain).block()) + .withMessage("No current user associated with this request"); verifyNoInteractions(chain); } @Test public void constructorUserDetailsServiceRequired() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("userDetailsService must be specified"); - this.switchUserWebFilter = new SwitchUserWebFilter(null, mock(ServerAuthenticationSuccessHandler.class), - mock(ServerAuthenticationFailureHandler.class)); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.switchUserWebFilter = new SwitchUserWebFilter(null, + mock(ServerAuthenticationSuccessHandler.class), mock(ServerAuthenticationFailureHandler.class))) + .withMessage("userDetailsService must be specified"); } @Test public void constructorServerAuthenticationSuccessHandlerRequired() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("successHandler must be specified"); - this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null, - mock(ServerAuthenticationFailureHandler.class)); + assertThatIllegalArgumentException() + .isThrownBy( + () -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), + null, mock(ServerAuthenticationFailureHandler.class))) + .withMessage("successHandler must be specified"); } @Test public void constructorSuccessTargetUrlRequired() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("successTargetUrl must be specified"); - this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null, - "failure/target/url"); + assertThatIllegalArgumentException().isThrownBy( + () -> this.switchUserWebFilter = new SwitchUserWebFilter(mock(ReactiveUserDetailsService.class), null, + "failure/target/url")) + .withMessage("successTargetUrl must be specified"); } @Test @@ -336,10 +337,9 @@ public class SwitchUserWebFilterTests { @Test public void setSecurityContextRepositoryWhenNullThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("securityContextRepository cannot be null"); - this.switchUserWebFilter.setSecurityContextRepository(null); - fail("Test should fail with exception"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.switchUserWebFilter.setSecurityContextRepository(null)) + .withMessage("securityContextRepository cannot be null"); } @Test @@ -357,18 +357,14 @@ public class SwitchUserWebFilterTests { @Test public void setExitUserUrlWhenNullThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); - this.switchUserWebFilter.setExitUserUrl(null); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl(null)) + .withMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); } @Test public void setExitUserUrlWhenInvalidUrlThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); - this.switchUserWebFilter.setExitUserUrl("wrongUrl"); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserUrl("wrongUrl")) + .withMessage("exitUserUrl cannot be empty and must be a valid redirect URL"); } @Test @@ -387,10 +383,8 @@ public class SwitchUserWebFilterTests { @Test public void setExitUserMatcherWhenNullThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("exitUserMatcher cannot be null"); - this.switchUserWebFilter.setExitUserMatcher(null); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setExitUserMatcher(null)) + .withMessage("exitUserMatcher cannot be null"); } @Test @@ -410,18 +404,14 @@ public class SwitchUserWebFilterTests { @Test public void setSwitchUserUrlWhenNullThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); - this.switchUserWebFilter.setSwitchUserUrl(null); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl(null)) + .withMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); } @Test public void setSwitchUserUrlWhenInvalidThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); - this.switchUserWebFilter.setSwitchUserUrl("wrongUrl"); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserUrl("wrongUrl")) + .withMessage("switchUserUrl cannot be empty and must be a valid redirect URL"); } @Test @@ -440,10 +430,8 @@ public class SwitchUserWebFilterTests { @Test public void setSwitchUserMatcherWhenNullThenThrowException() { - this.exceptionRule.expect(IllegalArgumentException.class); - this.exceptionRule.expectMessage("switchUserMatcher cannot be null"); - this.switchUserWebFilter.setSwitchUserMatcher(null); - fail("Test should fail with exception"); + assertThatIllegalArgumentException().isThrownBy(() -> this.switchUserWebFilter.setSwitchUserMatcher(null)) + .withMessage("switchUserMatcher cannot be null"); } @Test