diff --git a/spring-boot/src/main/java/org/springframework/boot/web/embedded/undertow/UndertowServletWebServerFactory.java b/spring-boot/src/main/java/org/springframework/boot/web/embedded/undertow/UndertowServletWebServerFactory.java index 636738c6dda..16b706284aa 100644 --- a/spring-boot/src/main/java/org/springframework/boot/web/embedded/undertow/UndertowServletWebServerFactory.java +++ b/spring-boot/src/main/java/org/springframework/boot/web/embedded/undertow/UndertowServletWebServerFactory.java @@ -19,12 +19,16 @@ package org.springframework.boot.web.embedded.undertow; import java.io.File; import java.io.IOException; import java.net.MalformedURLException; +import java.net.Socket; import java.net.URL; import java.net.URLConnection; import java.nio.charset.Charset; import java.security.KeyManagementException; import java.security.KeyStore; import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -37,8 +41,10 @@ import java.util.Set; import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; import javax.servlet.ServletContainerInitializer; import javax.servlet.ServletContext; import javax.servlet.ServletException; @@ -309,13 +315,23 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac keyPassword = ssl.getKeyStorePassword().toCharArray(); } keyManagerFactory.init(keyStore, keyPassword); - return keyManagerFactory.getKeyManagers(); + return getConfigurableAliasKeyManagers(ssl, keyManagerFactory.getKeyManagers()); } catch (Exception ex) { throw new IllegalStateException(ex); } } + private KeyManager[] getConfigurableAliasKeyManagers(Ssl ssl, KeyManager[] keyManagers) { + for (int i = 0; i < keyManagers.length; i++) { + if (keyManagers[i] instanceof X509ExtendedKeyManager) { + keyManagers[i] = new ConfigurableAliasKeyManager((X509ExtendedKeyManager) keyManagers[i], + ssl.getKeyAlias()); + } + } + return keyManagers; + } + private KeyStore getKeyStore() throws Exception { if (getSslStoreProvider() != null) { return getSslStoreProvider().getKeyStore(); @@ -691,6 +707,57 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac initializer.onStartup(servletContext); } } + } + + private static class ConfigurableAliasKeyManager extends X509ExtendedKeyManager { + + private final X509ExtendedKeyManager sourceKeyManager; + + private final String alias; + + ConfigurableAliasKeyManager(X509ExtendedKeyManager keyManager, String alias) { + this.sourceKeyManager = keyManager; + this.alias = alias; + } + + @Override + public String chooseEngineClientAlias(String[] strings, Principal[] principals, SSLEngine sslEngine) { + return this.sourceKeyManager.chooseEngineClientAlias(strings, principals, sslEngine); + } + + @Override + public String chooseEngineServerAlias(String s, Principal[] principals, SSLEngine sslEngine) { + if (this.alias == null) { + return this.sourceKeyManager.chooseEngineServerAlias(s, principals, sslEngine); + } + return this.alias; + } + + public String chooseClientAlias(String[] keyType, Principal[] issuers, + Socket socket) { + return this.sourceKeyManager.chooseClientAlias(keyType, issuers, socket); + } + + public String chooseServerAlias(String keyType, Principal[] issuers, + Socket socket) { + return this.sourceKeyManager.chooseServerAlias(keyType, issuers, socket); + } + + public X509Certificate[] getCertificateChain(String alias) { + return this.sourceKeyManager.getCertificateChain(alias); + } + + public String[] getClientAliases(String keyType, Principal[] issuers) { + return this.sourceKeyManager.getClientAliases(keyType, issuers); + } + + public PrivateKey getPrivateKey(String alias) { + return this.sourceKeyManager.getPrivateKey(alias); + } + + public String[] getServerAliases(String keyType, Principal[] issuers) { + return this.sourceKeyManager.getServerAliases(keyType, issuers); + } } diff --git a/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java index 4a50a1906db..a358255acf1 100644 --- a/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/web/servlet/server/AbstractServletWebServerFactoryTests.java @@ -34,6 +34,7 @@ import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -446,6 +447,24 @@ public abstract class AbstractServletWebServerFactoryTests { .contains("scheme=https"); } + @Test + public void sslKeyAlias() throws Exception { + AbstractEmbeddedServletContainerFactory factory = getFactory(); + factory.setSsl(getSsl(null, "password", "test-alias", "src/test/resources/test.jks")); + this.container = factory.getEmbeddedServletContainer( + new ServletRegistrationBean(new ExampleServlet(true, false), "/hello")); + this.container.start(); + SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( + new SSLContextBuilder() + .loadTrustMaterial(null, new SerialNumberValidatingTrustSelfSignedStrategy("77e7c302")).build()); + HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory) + .build(); + HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory( + httpClient); + assertThat(getResponse(getLocalUrl("https", "/hello"), requestFactory)) + .contains("scheme=https"); + } + @Test public void serverHeaderIsDisabledByDefaultWhenUsingSsl() throws Exception { AbstractServletWebServerFactory factory = getFactory(); @@ -659,13 +678,25 @@ public abstract class AbstractServletWebServerFactoryTests { return getSsl(clientAuth, keyPassword, keyStore, null, null, null); } + private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, String keyStore) { + return getSsl(clientAuth, keyPassword, keyAlias, keyStore, null, null, null); + } + private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyStore, String trustStore, String[] supportedProtocols, String[] ciphers) { + return getSsl(clientAuth, keyPassword, null, keyStore, trustStore, supportedProtocols, ciphers); + } + + private Ssl getSsl(ClientAuth clientAuth, String keyPassword, String keyAlias, String keyStore, + String trustStore, String[] supportedProtocols, String[] ciphers) { Ssl ssl = new Ssl(); ssl.setClientAuth(clientAuth); if (keyPassword != null) { ssl.setKeyPassword(keyPassword); } + if (keyAlias != null) { + ssl.setKeyAlias(keyAlias); + } if (keyStore != null) { ssl.setKeyStore(keyStore); ssl.setKeyStorePassword("secret"); @@ -1255,4 +1286,25 @@ public abstract class AbstractServletWebServerFactoryTests { } + /** + * {@link TrustSelfSignedStrategy} that also validates certificate serial + * number. + */ + private static final class SerialNumberValidatingTrustSelfSignedStrategy extends TrustSelfSignedStrategy { + + private final String serialNumber; + + private SerialNumberValidatingTrustSelfSignedStrategy(String serialNumber) { + this.serialNumber = serialNumber; + } + + @Override + public boolean isTrusted(X509Certificate[] chain, String authType) throws CertificateException { + String hexSerialNumber = chain[0].getSerialNumber().toString(16); + boolean isMatch = hexSerialNumber.equals(this.serialNumber); + return super.isTrusted(chain, authType) && isMatch; + } + + } + } diff --git a/spring-boot/src/test/resources/test.jks b/spring-boot/src/test/resources/test.jks index b10103d0d9d..1bce90bba66 100644 Binary files a/spring-boot/src/test/resources/test.jks and b/spring-boot/src/test/resources/test.jks differ