4 changed files with 323 additions and 0 deletions
@ -0,0 +1,132 @@
@@ -0,0 +1,132 @@
|
||||
/* |
||||
* Copyright 2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.server.authorization.web; |
||||
|
||||
import com.nimbusds.jose.JWSAlgorithm; |
||||
import com.nimbusds.jose.jwk.Curve; |
||||
import com.nimbusds.jose.jwk.ECKey; |
||||
import com.nimbusds.jose.jwk.JWK; |
||||
import com.nimbusds.jose.jwk.JWKSet; |
||||
import com.nimbusds.jose.jwk.KeyUse; |
||||
import com.nimbusds.jose.jwk.RSAKey; |
||||
import org.springframework.http.HttpMethod; |
||||
import org.springframework.http.MediaType; |
||||
import org.springframework.security.crypto.keys.KeyManager; |
||||
import org.springframework.security.crypto.keys.ManagedKey; |
||||
import org.springframework.security.web.util.matcher.AntPathRequestMatcher; |
||||
import org.springframework.security.web.util.matcher.RequestMatcher; |
||||
import org.springframework.util.Assert; |
||||
import org.springframework.web.filter.OncePerRequestFilter; |
||||
|
||||
import javax.servlet.FilterChain; |
||||
import javax.servlet.ServletException; |
||||
import javax.servlet.http.HttpServletRequest; |
||||
import javax.servlet.http.HttpServletResponse; |
||||
import java.io.IOException; |
||||
import java.io.Writer; |
||||
import java.security.interfaces.ECPublicKey; |
||||
import java.security.interfaces.RSAPublicKey; |
||||
import java.util.Objects; |
||||
import java.util.stream.Collectors; |
||||
|
||||
/** |
||||
* A {@code Filter} that processes JWK Set requests. |
||||
* |
||||
* @author Joe Grandja |
||||
* @since 0.0.1 |
||||
* @see KeyManager |
||||
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7517">JSON Web Key (JWK)</a> |
||||
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7517#section-5">Section 5 JWK Set Format</a> |
||||
*/ |
||||
public class JwkSetEndpointFilter extends OncePerRequestFilter { |
||||
/** |
||||
* The default endpoint {@code URI} for JWK Set requests. |
||||
*/ |
||||
public static final String DEFAULT_JWK_SET_ENDPOINT_URI = "/oauth2/jwks"; |
||||
|
||||
private final KeyManager keyManager; |
||||
private final RequestMatcher requestMatcher; |
||||
|
||||
/** |
||||
* Constructs a {@code JwkSetEndpointFilter} using the provided parameters. |
||||
* |
||||
* @param keyManager the key manager |
||||
*/ |
||||
public JwkSetEndpointFilter(KeyManager keyManager) { |
||||
this(keyManager, DEFAULT_JWK_SET_ENDPOINT_URI); |
||||
} |
||||
|
||||
/** |
||||
* Constructs a {@code JwkSetEndpointFilter} using the provided parameters. |
||||
* |
||||
* @param keyManager the key manager |
||||
* @param jwkSetEndpointUri the endpoint {@code URI} for JWK Set requests |
||||
*/ |
||||
public JwkSetEndpointFilter(KeyManager keyManager, String jwkSetEndpointUri) { |
||||
Assert.notNull(keyManager, "keyManager cannot be null"); |
||||
Assert.hasText(jwkSetEndpointUri, "jwkSetEndpointUri cannot be empty"); |
||||
this.keyManager = keyManager; |
||||
this.requestMatcher = new AntPathRequestMatcher(jwkSetEndpointUri, HttpMethod.GET.name()); |
||||
} |
||||
|
||||
@Override |
||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) |
||||
throws ServletException, IOException { |
||||
|
||||
if (!this.requestMatcher.matches(request)) { |
||||
filterChain.doFilter(request, response); |
||||
return; |
||||
} |
||||
|
||||
JWKSet jwkSet = buildJwkSet(); |
||||
|
||||
response.setContentType(MediaType.APPLICATION_JSON_VALUE); |
||||
try (Writer writer = response.getWriter()) { |
||||
writer.write(jwkSet.toJSONObject().toString()); |
||||
} |
||||
} |
||||
|
||||
private JWKSet buildJwkSet() { |
||||
return new JWKSet( |
||||
this.keyManager.getKeys().stream() |
||||
.filter(managedKey -> managedKey.isActive() && managedKey.isAsymmetric()) |
||||
.map(this::convert) |
||||
.filter(Objects::nonNull) |
||||
.collect(Collectors.toList()) |
||||
); |
||||
} |
||||
|
||||
private JWK convert(ManagedKey managedKey) { |
||||
JWK jwk = null; |
||||
if (managedKey.getPublicKey() instanceof RSAPublicKey) { |
||||
RSAPublicKey publicKey = (RSAPublicKey) managedKey.getPublicKey(); |
||||
jwk = new RSAKey.Builder(publicKey) |
||||
.keyUse(KeyUse.SIGNATURE) |
||||
.algorithm(JWSAlgorithm.RS256) |
||||
.keyID(managedKey.getKeyId()) |
||||
.build(); |
||||
} else if (managedKey.getPublicKey() instanceof ECPublicKey) { |
||||
ECPublicKey publicKey = (ECPublicKey) managedKey.getPublicKey(); |
||||
Curve curve = Curve.forECParameterSpec(publicKey.getParams()); |
||||
jwk = new ECKey.Builder(curve, publicKey) |
||||
.keyUse(KeyUse.SIGNATURE) |
||||
.algorithm(JWSAlgorithm.ES256) |
||||
.keyID(managedKey.getKeyId()) |
||||
.build(); |
||||
} |
||||
return jwk; |
||||
} |
||||
} |
||||
@ -0,0 +1,186 @@
@@ -0,0 +1,186 @@
|
||||
/* |
||||
* Copyright 2020 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.oauth2.server.authorization.web; |
||||
|
||||
import com.nimbusds.jose.JWSAlgorithm; |
||||
import com.nimbusds.jose.jwk.ECKey; |
||||
import com.nimbusds.jose.jwk.JWKSet; |
||||
import com.nimbusds.jose.jwk.KeyUse; |
||||
import com.nimbusds.jose.jwk.RSAKey; |
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import org.springframework.http.MediaType; |
||||
import org.springframework.mock.web.MockHttpServletRequest; |
||||
import org.springframework.mock.web.MockHttpServletResponse; |
||||
import org.springframework.security.crypto.keys.KeyManager; |
||||
import org.springframework.security.crypto.keys.ManagedKey; |
||||
import org.springframework.security.crypto.keys.TestManagedKeys; |
||||
|
||||
import javax.servlet.FilterChain; |
||||
import javax.servlet.http.HttpServletRequest; |
||||
import javax.servlet.http.HttpServletResponse; |
||||
import java.time.Instant; |
||||
import java.util.Collections; |
||||
import java.util.HashSet; |
||||
import java.util.stream.Collectors; |
||||
import java.util.stream.Stream; |
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat; |
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy; |
||||
import static org.mockito.ArgumentMatchers.any; |
||||
import static org.mockito.Mockito.mock; |
||||
import static org.mockito.Mockito.verify; |
||||
import static org.mockito.Mockito.verifyNoInteractions; |
||||
import static org.mockito.Mockito.when; |
||||
|
||||
/** |
||||
* Tests for {@link JwkSetEndpointFilter}. |
||||
* |
||||
* @author Joe Grandja |
||||
*/ |
||||
public class JwkSetEndpointFilterTests { |
||||
private KeyManager keyManager; |
||||
private JwkSetEndpointFilter filter; |
||||
|
||||
@Before |
||||
public void setUp() { |
||||
this.keyManager = mock(KeyManager.class); |
||||
this.filter = new JwkSetEndpointFilter(this.keyManager); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenKeyManagerNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new JwkSetEndpointFilter(null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("keyManager cannot be null"); |
||||
} |
||||
|
||||
@Test |
||||
public void constructorWhenJwkSetEndpointUriNullThenThrowIllegalArgumentException() { |
||||
assertThatThrownBy(() -> new JwkSetEndpointFilter(this.keyManager, null)) |
||||
.isInstanceOf(IllegalArgumentException.class) |
||||
.hasMessage("jwkSetEndpointUri cannot be empty"); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenNotJwkSetRequestThenNotProcessed() throws Exception { |
||||
String requestUri = "/path"; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenJwkSetRequestPostThenNotProcessed() throws Exception { |
||||
String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenAsymmetricKeysThenJwkSetResponse() throws Exception { |
||||
ManagedKey rsaManagedKey = TestManagedKeys.rsaManagedKey().build(); |
||||
ManagedKey ecManagedKey = TestManagedKeys.ecManagedKey().build(); |
||||
when(this.keyManager.getKeys()).thenReturn( |
||||
Stream.of(rsaManagedKey, ecManagedKey).collect(Collectors.toSet())); |
||||
|
||||
String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); |
||||
|
||||
JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); |
||||
assertThat(jwkSet.getKeys()).hasSize(2); |
||||
|
||||
RSAKey rsaJwk = (RSAKey) jwkSet.getKeyByKeyId(rsaManagedKey.getKeyId()); |
||||
assertThat(rsaJwk).isNotNull(); |
||||
assertThat(rsaJwk.toRSAPublicKey()).isEqualTo(rsaManagedKey.getPublicKey()); |
||||
assertThat(rsaJwk.toRSAPrivateKey()).isNull(); |
||||
assertThat(rsaJwk.getKeyUse()).isEqualTo(KeyUse.SIGNATURE); |
||||
assertThat(rsaJwk.getAlgorithm()).isEqualTo(JWSAlgorithm.RS256); |
||||
|
||||
ECKey ecJwk = (ECKey) jwkSet.getKeyByKeyId(ecManagedKey.getKeyId()); |
||||
assertThat(ecJwk).isNotNull(); |
||||
assertThat(ecJwk.toECPublicKey()).isEqualTo(ecManagedKey.getPublicKey()); |
||||
assertThat(ecJwk.toECPublicKey()).isEqualTo(ecManagedKey.getPublicKey()); |
||||
assertThat(ecJwk.toECPrivateKey()).isNull(); |
||||
assertThat(ecJwk.getKeyUse()).isEqualTo(KeyUse.SIGNATURE); |
||||
assertThat(ecJwk.getAlgorithm()).isEqualTo(JWSAlgorithm.ES256); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenSymmetricKeysThenJwkSetResponseEmpty() throws Exception { |
||||
ManagedKey secretManagedKey = TestManagedKeys.secretManagedKey().build(); |
||||
when(this.keyManager.getKeys()).thenReturn( |
||||
new HashSet<>(Collections.singleton(secretManagedKey))); |
||||
|
||||
String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); |
||||
|
||||
JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); |
||||
assertThat(jwkSet.getKeys()).isEmpty(); |
||||
} |
||||
|
||||
@Test |
||||
public void doFilterWhenNoActiveKeysThenJwkSetResponseEmpty() throws Exception { |
||||
ManagedKey rsaManagedKey = TestManagedKeys.rsaManagedKey().deactivatedOn(Instant.now()).build(); |
||||
ManagedKey ecManagedKey = TestManagedKeys.ecManagedKey().deactivatedOn(Instant.now()).build(); |
||||
when(this.keyManager.getKeys()).thenReturn( |
||||
Stream.of(rsaManagedKey, ecManagedKey).collect(Collectors.toSet())); |
||||
|
||||
String requestUri = JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI; |
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); |
||||
request.setServletPath(requestUri); |
||||
MockHttpServletResponse response = new MockHttpServletResponse(); |
||||
FilterChain filterChain = mock(FilterChain.class); |
||||
|
||||
this.filter.doFilter(request, response, filterChain); |
||||
|
||||
verifyNoInteractions(filterChain); |
||||
|
||||
assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); |
||||
|
||||
JWKSet jwkSet = JWKSet.parse(response.getContentAsString()); |
||||
assertThat(jwkSet.getKeys()).isEmpty(); |
||||
} |
||||
} |
||||
Loading…
Reference in new issue