Browse Source

Polish spring-security-test main code

Manually polish `spring-security-test` following the formatting
and checkstyle fixes.

Issue gh-8945
pull/8983/head
Phillip Webb 5 years ago committed by Rob Winch
parent
commit
ef951bae90
  1. 9
      test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java
  2. 8
      test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java
  3. 1
      test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java
  4. 13
      test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java
  5. 1
      test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java
  6. 2
      test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java
  7. 36
      test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java
  8. 33
      test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java
  9. 76
      test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java
  10. 26
      test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java
  11. 19
      test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java
  12. 23
      test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

9
test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolder.java

@ -56,12 +56,14 @@ import org.springframework.util.Assert;
* @author Rob Winch * @author Rob Winch
* @author Tadaya Tsuyukubo * @author Tadaya Tsuyukubo
* @since 4.0 * @since 4.0
*
*/ */
public final class TestSecurityContextHolder { public final class TestSecurityContextHolder {
private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>(); private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();
private TestSecurityContextHolder() {
}
/** /**
* Clears the {@link SecurityContext} from {@link TestSecurityContextHolder} and * Clears the {@link SecurityContext} from {@link TestSecurityContextHolder} and
* {@link SecurityContextHolder}. * {@link SecurityContextHolder}.
@ -77,12 +79,10 @@ public final class TestSecurityContextHolder {
*/ */
public static SecurityContext getContext() { public static SecurityContext getContext() {
SecurityContext ctx = contextHolder.get(); SecurityContext ctx = contextHolder.get();
if (ctx == null) { if (ctx == null) {
ctx = getDefaultContext(); ctx = getDefaultContext();
contextHolder.set(ctx); contextHolder.set(ctx);
} }
return ctx; return ctx;
} }
@ -120,7 +120,4 @@ public final class TestSecurityContextHolder {
return SecurityContextHolder.getContext(); return SecurityContextHolder.getContext();
} }
private TestSecurityContextHolder() {
}
} }

8
test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java

@ -52,9 +52,11 @@ public class ReactorContextTestExecutionListener extends DelegatingTestExecution
} }
private static TestExecutionListener createDelegate() { private static TestExecutionListener createDelegate() {
return ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader()) if (!ClassUtils.isPresent(HOOKS_CLASS_NAME, ReactorContextTestExecutionListener.class.getClassLoader())) {
? new DelegateTestExecutionListener() : new AbstractTestExecutionListener() { return new AbstractTestExecutionListener() {
}; };
}
return new DelegateTestExecutionListener();
} }
/** /**

1
test/src/main/java/org/springframework/security/test/context/support/TestExecutionEvent.java

@ -33,6 +33,7 @@ public enum TestExecutionEvent {
* event. * event.
*/ */
TEST_METHOD, TEST_METHOD,
/** /**
* Associated to * Associated to
* {@link org.springframework.test.context.TestExecutionListener#beforeTestExecution(TestContext)} * {@link org.springframework.test.context.TestExecutionListener#beforeTestExecution(TestContext)}

13
test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java

@ -27,6 +27,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.User;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
@ -41,21 +42,14 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
@Override @Override
public SecurityContext createSecurityContext(WithMockUser withUser) { public SecurityContext createSecurityContext(WithMockUser withUser) {
String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value(); String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value();
if (username == null) { Assert.notNull(username, () -> withUser + " cannot have null username on both username and value properties");
throw new IllegalArgumentException(
withUser + " cannot have null username on both username and value properties");
}
List<GrantedAuthority> grantedAuthorities = new ArrayList<>(); List<GrantedAuthority> grantedAuthorities = new ArrayList<>();
for (String authority : withUser.authorities()) { for (String authority : withUser.authorities()) {
grantedAuthorities.add(new SimpleGrantedAuthority(authority)); grantedAuthorities.add(new SimpleGrantedAuthority(authority));
} }
if (grantedAuthorities.isEmpty()) { if (grantedAuthorities.isEmpty()) {
for (String role : withUser.roles()) { for (String role : withUser.roles()) {
if (role.startsWith("ROLE_")) { Assert.isTrue(!role.startsWith("ROLE_"), () -> "roles cannot start with ROLE_ Got " + role);
throw new IllegalArgumentException("roles cannot start with ROLE_ Got " + role);
}
grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_" + role)); grantedAuthorities.add(new SimpleGrantedAuthority("ROLE_" + role));
} }
} }
@ -63,7 +57,6 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
throw new IllegalStateException("You cannot define roles attribute " + Arrays.asList(withUser.roles()) throw new IllegalStateException("You cannot define roles attribute " + Arrays.asList(withUser.roles())
+ " with authorities attribute " + Arrays.asList(withUser.authorities())); + " with authorities attribute " + Arrays.asList(withUser.authorities()));
} }
User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities); User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities);
Authentication authentication = new UsernamePasswordAuthenticationToken(principal, principal.getPassword(), Authentication authentication = new UsernamePasswordAuthenticationToken(principal, principal.getPassword(),
principal.getAuthorities()); principal.getAuthorities());

1
test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java

@ -68,7 +68,6 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
if (testSecurityContext == null) { if (testSecurityContext == null) {
return; return;
} }
Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier(); Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier();
if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) { if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
TestSecurityContextHolder.setContext(supplier.get()); TestSecurityContextHolder.setContext(supplier.get());

2
test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java

@ -84,7 +84,7 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
: this.beans.getBean(ReactiveUserDetailsService.class); : this.beans.getBean(ReactiveUserDetailsService.class);
return new ReactiveUserDetailsServiceAdapter(reactiveUserDetailsService); return new ReactiveUserDetailsServiceAdapter(reactiveUserDetailsService);
} }
catch (NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException notReactive) { catch (NoSuchBeanDefinitionException | BeanNotOfRequiredTypeException ex) {
return null; return null;
} }
} }

36
test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java

@ -108,10 +108,12 @@ public final class SecurityMockServerConfigurers {
*/ */
public static MockServerConfigurer springSecurity() { public static MockServerConfigurer springSecurity() {
return new MockServerConfigurer() { return new MockServerConfigurer() {
@Override @Override
public void beforeServerCreated(WebHttpHandlerBuilder builder) { public void beforeServerCreated(WebHttpHandlerBuilder builder) {
builder.filters((filters) -> filters.add(0, new MutatorFilter())); builder.filters((filters) -> filters.add(0, new MutatorFilter()));
} }
}; };
} }
@ -992,26 +994,22 @@ public final class SecurityMockServerConfigurers {
} }
private Collection<GrantedAuthority> getAuthorities() { private Collection<GrantedAuthority> getAuthorities() {
if (this.authorities == null) { if (this.authorities != null) {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
else {
return this.authorities; return this.authorities;
} }
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
} }
private OidcIdToken getOidcIdToken() { private OidcIdToken getOidcIdToken() {
if (this.idToken == null) { if (this.idToken != null) {
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
else {
return this.idToken; return this.idToken;
} }
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
} }
private OidcUserInfo getOidcUserInfo() { private OidcUserInfo getOidcUserInfo() {
@ -1071,7 +1069,6 @@ public final class SecurityMockServerConfigurers {
*/ */
public OAuth2ClientMutator clientRegistration( public OAuth2ClientMutator clientRegistration(
Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) { Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
ClientRegistration.Builder builder = clientRegistrationBuilder(); ClientRegistration.Builder builder = clientRegistrationBuilder();
clientRegistrationConfigurer.accept(builder); clientRegistrationConfigurer.accept(builder);
this.clientRegistration = builder.build(); this.clientRegistration = builder.build();
@ -1108,7 +1105,6 @@ public final class SecurityMockServerConfigurers {
@Override @Override
public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) { public void afterConfigureAdded(WebTestClient.MockServerSpec<?> serverSpec) {
} }
@Override @Override
@ -1134,10 +1130,8 @@ public final class SecurityMockServerConfigurers {
} }
private OAuth2AuthorizedClient getClient() { private OAuth2AuthorizedClient getClient() {
if (this.clientRegistration == null) { Assert.notNull(this.clientRegistration,
throw new IllegalArgumentException( "Please specify a ClientRegistration via one of the clientRegistration methods");
"Please specify a ClientRegistration via one " + "of the clientRegistration methods");
}
return new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken); return new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, this.accessToken);
} }
@ -1173,9 +1167,7 @@ public final class SecurityMockServerConfigurers {
OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME); OAuth2AuthorizedClient client = exchange.getAttribute(TOKEN_ATTR_NAME);
return Mono.just(client); return Mono.just(client);
} }
else { return this.delegate.authorize(authorizeRequest);
return this.delegate.authorize(authorizeRequest);
}
} }
static void enable(ServerWebExchange exchange) { static void enable(ServerWebExchange exchange) {

33
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuilders.java

@ -36,10 +36,12 @@ import static org.springframework.test.web.servlet.request.MockMvcRequestBuilder
* *
* @author Rob Winch * @author Rob Winch
* @since 4.0 * @since 4.0
*
*/ */
public final class SecurityMockMvcRequestBuilders { public final class SecurityMockMvcRequestBuilders {
private SecurityMockMvcRequestBuilders() {
}
/** /**
* Creates a request (including any necessary {@link CsrfToken}) that will submit a * Creates a request (including any necessary {@link CsrfToken}) that will submit a
* form based login to POST "/login". * form based login to POST "/login".
@ -91,18 +93,18 @@ public final class SecurityMockMvcRequestBuilders {
private Mergeable parent; private Mergeable parent;
private LogoutRequestBuilder() {
}
@Override @Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) { public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl).accept(MediaType.TEXT_HTML, MockHttpServletRequestBuilder logoutRequest = post(this.logoutUrl).accept(MediaType.TEXT_HTML,
MediaType.ALL); MediaType.ALL);
if (this.parent != null) { if (this.parent != null) {
logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent); logoutRequest = (MockHttpServletRequestBuilder) logoutRequest.merge(this.parent);
} }
MockHttpServletRequest request = logoutRequest.buildRequest(servletContext); MockHttpServletRequest request = logoutRequest.buildRequest(servletContext);
logoutRequest.postProcessRequest(request); logoutRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request); return this.postProcessor.postProcessRequest(request);
} }
@ -141,12 +143,7 @@ public final class SecurityMockMvcRequestBuilders {
this.parent = (Mergeable) parent; this.parent = (Mergeable) parent;
return this; return this;
} }
else { throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
}
private LogoutRequestBuilder() {
} }
} }
@ -175,18 +172,18 @@ public final class SecurityMockMvcRequestBuilders {
private RequestPostProcessor postProcessor = csrf(); private RequestPostProcessor postProcessor = csrf();
private FormLoginRequestBuilder() {
}
@Override @Override
public MockHttpServletRequest buildRequest(ServletContext servletContext) { public MockHttpServletRequest buildRequest(ServletContext servletContext) {
MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl).accept(this.acceptMediaType) MockHttpServletRequestBuilder loginRequest = post(this.loginProcessingUrl).accept(this.acceptMediaType)
.param(this.usernameParam, this.username).param(this.passwordParam, this.password); .param(this.usernameParam, this.username).param(this.passwordParam, this.password);
if (this.parent != null) { if (this.parent != null) {
loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent); loginRequest = (MockHttpServletRequestBuilder) loginRequest.merge(this.parent);
} }
MockHttpServletRequest request = loginRequest.buildRequest(servletContext); MockHttpServletRequest request = loginRequest.buildRequest(servletContext);
loginRequest.postProcessRequest(request); loginRequest.postProcessRequest(request);
return this.postProcessor.postProcessRequest(request); return this.postProcessor.postProcessRequest(request);
} }
@ -305,17 +302,9 @@ public final class SecurityMockMvcRequestBuilders {
this.parent = (Mergeable) parent; this.parent = (Mergeable) parent;
return this; return this;
} }
else { throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
throw new IllegalArgumentException("Cannot merge with [" + parent.getClass().getName() + "]");
}
} }
private FormLoginRequestBuilder() {
}
}
private SecurityMockMvcRequestBuilders() {
} }
} }

76
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@ -116,6 +116,9 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl
*/ */
public final class SecurityMockMvcRequestPostProcessors { public final class SecurityMockMvcRequestPostProcessors {
private SecurityMockMvcRequestPostProcessors() {
}
/** /**
* Creates a DigestRequestPostProcessor that enables easily adding digest based * Creates a DigestRequestPostProcessor that enables easily adding digest based
* authentication to a request. * authentication to a request.
@ -634,7 +637,6 @@ public final class SecurityMockMvcRequestPostProcessors {
String toDigest = expiryTime + ":" + "key"; String toDigest = expiryTime + ":" + "key";
String signatureValue = md5Hex(toDigest); String signatureValue = md5Hex(toDigest);
String nonceValue = expiryTime + ":" + signatureValue; String nonceValue = expiryTime + ":" + signatureValue;
return new String(Base64.getEncoder().encode(nonceValue.getBytes())); return new String(Base64.getEncoder().encode(nonceValue.getBytes()));
} }
@ -649,7 +651,6 @@ public final class SecurityMockMvcRequestPostProcessors {
@Override @Override
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) { public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
request.addHeader("Authorization", createAuthorizationHeader(request)); request.addHeader("Authorization", createAuthorizationHeader(request));
return request; return request;
} }
@ -676,28 +677,19 @@ public final class SecurityMockMvcRequestPostProcessors {
String a1Md5 = encodePasswordInA1Format(username, realm, password); String a1Md5 = encodePasswordInA1Format(username, realm, password);
String a2 = httpMethod + ":" + uri; String a2 = httpMethod + ":" + uri;
String a2Md5 = md5Hex(a2); String a2Md5 = md5Hex(a2);
String digest;
if (qop == null) { if (qop == null) {
// as per RFC 2069 compliant clients (also reaffirmed by RFC 2617) // as per RFC 2069 compliant clients (also reaffirmed by RFC 2617)
digest = a1Md5 + ":" + nonce + ":" + a2Md5; return md5Hex(a1Md5 + ":" + nonce + ":" + a2Md5);
} }
else if ("auth".equals(qop)) { if ("auth".equals(qop)) {
// As per RFC 2617 compliant clients // As per RFC 2617 compliant clients
digest = a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5; return md5Hex(a1Md5 + ":" + nonce + ":" + nc + ":" + cnonce + ":" + qop + ":" + a2Md5);
} }
else { throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
throw new IllegalArgumentException("This method does not support a qop: '" + qop + "'");
}
return md5Hex(digest);
} }
static String encodePasswordInA1Format(String username, String realm, String password) { static String encodePasswordInA1Format(String username, String realm, String password) {
String a1 = username + ":" + realm + ":" + password; return md5Hex(username + ":" + realm + ":" + password);
return md5Hex(a1);
} }
private static String md5Hex(String a2) { private static String md5Hex(String a2) {
@ -736,15 +728,11 @@ public final class SecurityMockMvcRequestPostProcessors {
securityContextRepository = new TestSecurityContextRepository(securityContextRepository); securityContextRepository = new TestSecurityContextRepository(securityContextRepository);
WebTestUtils.setSecurityContextRepository(request, securityContextRepository); WebTestUtils.setSecurityContextRepository(request, securityContextRepository);
} }
HttpServletResponse response = new MockHttpServletResponse(); HttpServletResponse response = new MockHttpServletResponse();
HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response); HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response);
securityContextRepository.loadContext(requestResponseHolder); securityContextRepository.loadContext(requestResponseHolder);
request = requestResponseHolder.getRequest(); request = requestResponseHolder.getRequest();
response = requestResponseHolder.getResponse(); response = requestResponseHolder.getResponse();
securityContextRepository.saveContext(securityContext, request, response); securityContextRepository.saveContext(securityContext, request, response);
} }
@ -812,12 +800,10 @@ public final class SecurityMockMvcRequestPostProcessors {
if (existingContext != null) { if (existingContext != null) {
return request; return request;
} }
SecurityContext context = TestSecurityContextHolder.getContext(); SecurityContext context = TestSecurityContextHolder.getContext();
if (!this.EMPTY.equals(context)) { if (!this.EMPTY.equals(context)) {
save(context, request); save(context, request);
} }
return request; return request;
} }
@ -889,7 +875,6 @@ public final class SecurityMockMvcRequestPostProcessors {
UserDetailsRequestPostProcessor(UserDetails user) { UserDetailsRequestPostProcessor(UserDetails user) {
Authentication token = new UsernamePasswordAuthenticationToken(user, user.getPassword(), Authentication token = new UsernamePasswordAuthenticationToken(user, user.getPassword(),
user.getAuthorities()); user.getAuthorities());
this.delegate = new AuthenticationRequestPostProcessor(token); this.delegate = new AuthenticationRequestPostProcessor(token);
} }
@ -948,13 +933,9 @@ public final class SecurityMockMvcRequestPostProcessors {
public UserRequestPostProcessor roles(String... roles) { public UserRequestPostProcessor roles(String... roles) {
List<GrantedAuthority> authorities = new ArrayList<>(roles.length); List<GrantedAuthority> authorities = new ArrayList<>(roles.length);
for (String role : roles) { for (String role : roles) {
if (role.startsWith(ROLE_PREFIX)) { Assert.isTrue(!role.startsWith(ROLE_PREFIX), () -> "Role should not start with " + ROLE_PREFIX
throw new IllegalArgumentException("Role should not start with " + ROLE_PREFIX + " since this method automatically prefixes with this value. Got " + role);
+ " since this method automatically prefixes with this value. Got " + role); authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role));
}
else {
authorities.add(new SimpleGrantedAuthority(ROLE_PREFIX + role));
}
} }
this.authorities = authorities; this.authorities = authorities;
return this; return this;
@ -1027,8 +1008,7 @@ public final class SecurityMockMvcRequestPostProcessors {
private String headerValue; private String headerValue;
private HttpBasicRequestPostProcessor(String username, String password) { private HttpBasicRequestPostProcessor(String username, String password) {
byte[] toEncode; byte[] toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8);
toEncode = (username + ":" + password).getBytes(StandardCharsets.UTF_8);
this.headerValue = "Basic " + new String(Base64.getEncoder().encode(toEncode)); this.headerValue = "Basic " + new String(Base64.getEncoder().encode(toEncode));
} }
@ -1356,7 +1336,6 @@ public final class SecurityMockMvcRequestPostProcessors {
OAuth2User oauth2User = this.oauth2User.get(); OAuth2User oauth2User = this.oauth2User.get();
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(), OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(oauth2User, oauth2User.getAuthorities(),
this.clientRegistration.getRegistrationId()); this.clientRegistration.getRegistrationId());
request = new AuthenticationRequestPostProcessor(token).postProcessRequest(request); request = new AuthenticationRequestPostProcessor(token).postProcessRequest(request);
return new OAuth2ClientRequestPostProcessor().clientRegistration(this.clientRegistration) return new OAuth2ClientRequestPostProcessor().clientRegistration(this.clientRegistration)
.principalName(oauth2User.getName()).accessToken(this.accessToken).postProcessRequest(request); .principalName(oauth2User.getName()).accessToken(this.accessToken).postProcessRequest(request);
@ -1504,26 +1483,22 @@ public final class SecurityMockMvcRequestPostProcessors {
} }
private Collection<GrantedAuthority> getAuthorities() { private Collection<GrantedAuthority> getAuthorities() {
if (this.authorities == null) { if (this.authorities != null) {
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}
else {
return this.authorities; return this.authorities;
} }
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OidcUserAuthority(getOidcIdToken(), getOidcUserInfo()));
for (String authority : this.accessToken.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
} }
private OidcIdToken getOidcIdToken() { private OidcIdToken getOidcIdToken() {
if (this.idToken == null) { if (this.idToken != null) {
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
}
else {
return this.idToken; return this.idToken;
} }
return new OidcIdToken("id-token", null, null, Collections.singletonMap(IdTokenClaimNames.SUB, "user"));
} }
private OidcUserInfo getOidcUserInfo() { private OidcUserInfo getOidcUserInfo() {
@ -1577,7 +1552,6 @@ public final class SecurityMockMvcRequestPostProcessors {
*/ */
public OAuth2ClientRequestPostProcessor clientRegistration( public OAuth2ClientRequestPostProcessor clientRegistration(
Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) { Consumer<ClientRegistration.Builder> clientRegistrationConfigurer) {
ClientRegistration.Builder builder = clientRegistrationBuilder(); ClientRegistration.Builder builder = clientRegistrationBuilder();
clientRegistrationConfigurer.accept(builder); clientRegistrationConfigurer.accept(builder);
this.clientRegistration = builder.build(); this.clientRegistration = builder.build();
@ -1613,7 +1587,6 @@ public final class SecurityMockMvcRequestPostProcessors {
} }
OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName, OAuth2AuthorizedClient client = new OAuth2AuthorizedClient(this.clientRegistration, this.principalName,
this.accessToken); this.accessToken);
OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils OAuth2AuthorizedClientManager authorizationClientManager = OAuth2ClientServletTestUtils
.getOAuth2AuthorizedClientManager(request); .getOAuth2AuthorizedClientManager(request);
if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) { if (!(authorizationClientManager instanceof TestOAuth2AuthorizedClientManager)) {
@ -1654,9 +1627,7 @@ public final class SecurityMockMvcRequestPostProcessors {
if (isEnabled(request)) { if (isEnabled(request)) {
return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME); return (OAuth2AuthorizedClient) request.getAttribute(TOKEN_ATTR_NAME);
} }
else { return this.delegate.authorize(authorizeRequest);
return this.delegate.authorize(authorizeRequest);
}
} }
static void enable(HttpServletRequest request) { static void enable(HttpServletRequest request) {
@ -1762,7 +1733,4 @@ public final class SecurityMockMvcRequestPostProcessors {
} }
private SecurityMockMvcRequestPostProcessors() {
}
} }

26
test/src/main/java/org/springframework/security/test/web/servlet/response/SecurityMockMvcResultMatchers.java

@ -43,6 +43,9 @@ import org.springframework.test.web.servlet.ResultMatcher;
*/ */
public final class SecurityMockMvcResultMatchers { public final class SecurityMockMvcResultMatchers {
private SecurityMockMvcResultMatchers() {
}
/** /**
* {@link ResultMatcher} that verifies that a specified user is authenticated. * {@link ResultMatcher} that verifies that a specified user is authenticated.
* @return the {@link AuthenticatedMatcher} to use * @return the {@link AuthenticatedMatcher} to use
@ -90,29 +93,26 @@ public final class SecurityMockMvcResultMatchers {
private Consumer<Authentication> assertAuthentication; private Consumer<Authentication> assertAuthentication;
AuthenticatedMatcher() {
}
@Override @Override
public void match(MvcResult result) { public void match(MvcResult result) {
SecurityContext context = load(result); SecurityContext context = load(result);
Authentication auth = context.getAuthentication(); Authentication auth = context.getAuthentication();
AssertionErrors.assertTrue("Authentication should not be null", auth != null); AssertionErrors.assertTrue("Authentication should not be null", auth != null);
if (this.assertAuthentication != null) { if (this.assertAuthentication != null) {
this.assertAuthentication.accept(auth); this.assertAuthentication.accept(auth);
} }
if (this.expectedContext != null) { if (this.expectedContext != null) {
AssertionErrors.assertEquals(this.expectedContext + " does not equal " + context, this.expectedContext, AssertionErrors.assertEquals(this.expectedContext + " does not equal " + context, this.expectedContext,
context); context);
} }
if (this.expectedAuthentication != null) { if (this.expectedAuthentication != null) {
AssertionErrors.assertEquals( AssertionErrors.assertEquals(
this.expectedAuthentication + " does not equal " + context.getAuthentication(), this.expectedAuthentication + " does not equal " + context.getAuthentication(),
this.expectedAuthentication, context.getAuthentication()); this.expectedAuthentication, context.getAuthentication());
} }
if (this.expectedAuthenticationPrincipal != null) { if (this.expectedAuthenticationPrincipal != null) {
AssertionErrors.assertTrue("Authentication cannot be null", context.getAuthentication() != null); AssertionErrors.assertTrue("Authentication cannot be null", context.getAuthentication() != null);
AssertionErrors.assertEquals( AssertionErrors.assertEquals(
@ -120,14 +120,12 @@ public final class SecurityMockMvcResultMatchers {
+ context.getAuthentication().getPrincipal(), + context.getAuthentication().getPrincipal(),
this.expectedAuthenticationPrincipal, context.getAuthentication().getPrincipal()); this.expectedAuthenticationPrincipal, context.getAuthentication().getPrincipal());
} }
if (this.expectedAuthenticationName != null) { if (this.expectedAuthenticationName != null) {
AssertionErrors.assertTrue("Authentication cannot be null", auth != null); AssertionErrors.assertTrue("Authentication cannot be null", auth != null);
String name = auth.getName(); String name = auth.getName();
AssertionErrors.assertEquals(this.expectedAuthenticationName + " does not equal " + name, AssertionErrors.assertEquals(this.expectedAuthenticationName + " does not equal " + name,
this.expectedAuthenticationName, name); this.expectedAuthenticationName, name);
} }
if (this.expectedGrantedAuthorities != null) { if (this.expectedGrantedAuthorities != null) {
AssertionErrors.assertTrue("Authentication cannot be null", auth != null); AssertionErrors.assertTrue("Authentication cannot be null", auth != null);
Collection<? extends GrantedAuthority> authorities = auth.getAuthorities(); Collection<? extends GrantedAuthority> authorities = auth.getAuthorities();
@ -222,9 +220,6 @@ public final class SecurityMockMvcResultMatchers {
return withAuthorities(authorities); return withAuthorities(authorities);
} }
AuthenticatedMatcher() {
}
} }
/** /**
@ -238,6 +233,9 @@ public final class SecurityMockMvcResultMatchers {
private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl(); private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
private UnAuthenticatedMatcher() {
}
@Override @Override
public void match(MvcResult result) { public void match(MvcResult result) {
SecurityContext context = load(result); SecurityContext context = load(result);
@ -247,12 +245,6 @@ public final class SecurityMockMvcResultMatchers {
authentication == null || this.trustResolver.isAnonymous(authentication)); authentication == null || this.trustResolver.isAnonymous(authentication));
} }
private UnAuthenticatedMatcher() {
}
}
private SecurityMockMvcResultMatchers() {
} }
} }

19
test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java

@ -29,6 +29,7 @@ import org.springframework.security.config.BeanIds;
import org.springframework.test.web.servlet.request.RequestPostProcessor; import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder; import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
import org.springframework.test.web.servlet.setup.MockMvcConfigurerAdapter; import org.springframework.test.web.servlet.setup.MockMvcConfigurerAdapter;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext; import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.testSecurityContext;
@ -72,15 +73,11 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
if (getSpringSecurityFilterChain() == null && context.containsBean(securityBeanId)) { if (getSpringSecurityFilterChain() == null && context.containsBean(securityBeanId)) {
setSpringSecurityFitlerChain(context.getBean(securityBeanId, Filter.class)); setSpringSecurityFitlerChain(context.getBean(securityBeanId, Filter.class));
} }
Assert.state(getSpringSecurityFilterChain() != null,
if (getSpringSecurityFilterChain() == null) { () -> "springSecurityFilterChain cannot be null. Ensure a Bean with the name " + securityBeanId
throw new IllegalStateException("springSecurityFilterChain cannot be null. Ensure a Bean with the name " + " implementing Filter is present or inject the Filter to be used.");
+ securityBeanId + " implementing Filter is present or inject the Filter to be used.");
}
// This is used by other test support to obtain the FilterChainProxy // This is used by other test support to obtain the FilterChainProxy
context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, getSpringSecurityFilterChain()); context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN, getSpringSecurityFilterChain());
return testSecurityContext(); return testSecurityContext();
} }
@ -118,11 +115,9 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
Filter getDelegate() { Filter getDelegate() {
Filter result = this.delegate; Filter result = this.delegate;
if (result == null) { Assert.state(result != null,
throw new IllegalStateException( () -> "delegate cannot be null. Ensure a Bean with the name " + BeanIds.SPRING_SECURITY_FILTER_CHAIN
"delegate cannot be null. Ensure a Bean with the name " + BeanIds.SPRING_SECURITY_FILTER_CHAIN + " implementing Filter is present or inject the Filter to be used.");
+ " implementing Filter is present or inject the Filter to be used.");
}
return result; return result;
} }

23
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@ -47,6 +47,9 @@ public abstract class WebTestUtils {
private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository(); private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
private WebTestUtils() {
}
/** /**
* Gets the {@link SecurityContextRepository} for the specified * Gets the {@link SecurityContextRepository} for the specified
* {@link HttpServletRequest}. If one is not found, a default * {@link HttpServletRequest}. If one is not found, a default
@ -134,18 +137,16 @@ public abstract class WebTestUtils {
} }
WebApplicationContext webApplicationContext = WebApplicationContextUtils WebApplicationContext webApplicationContext = WebApplicationContextUtils
.getWebApplicationContext(servletContext); .getWebApplicationContext(servletContext);
if (webApplicationContext != null) { if (webApplicationContext == null) {
try { return null;
return webApplicationContext.getBean(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME, }
Filter.class); try {
} String beanName = AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME;
catch (NoSuchBeanDefinitionException notFound) { return webApplicationContext.getBean(beanName, Filter.class);
} }
catch (NoSuchBeanDefinitionException ex) {
return null;
} }
return null;
}
private WebTestUtils() {
} }
} }

Loading…
Cancel
Save