diff --git a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizer.java b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizer.java index 1b4a99a55fa..3269ee08bc5 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizer.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizer.java @@ -24,7 +24,6 @@ import org.springframework.boot.autoconfigure.web.ServerProperties; import org.springframework.boot.cloud.CloudPlatform; import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory; -import org.springframework.boot.web.embedded.netty.NettyServerCustomizer; import org.springframework.boot.web.server.WebServerFactoryCustomizer; import org.springframework.core.Ordered; import org.springframework.core.env.Environment; @@ -58,11 +57,11 @@ public class NettyWebServerFactoryCustomizer @Override public void customize(NettyReactiveWebServerFactory factory) { factory.setUseForwardHeaders(getOrDeduceUseForwardHeaders()); - PropertyMapper propertyMapper = PropertyMapper.get(); - propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize).whenNonNull().asInt(DataSize::toBytes) + PropertyMapper propertyMapper = PropertyMapper.get().alwaysApplyingWhenNonNull(); + propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize) .to((maxHttpRequestHeaderSize) -> customizeMaxHttpHeaderSize(factory, maxHttpRequestHeaderSize)); - propertyMapper.from(this.serverProperties::getConnectionTimeout).whenNonNull().asInt(Duration::toMillis) - .to((duration) -> factory.addServerCustomizers(getConnectionTimeOutCustomizer(duration))); + propertyMapper.from(this.serverProperties::getConnectionTimeout) + .to((connectionTimeout) -> customizeConnectionTimeout(factory, connectionTimeout)); } private boolean getOrDeduceUseForwardHeaders() { @@ -73,14 +72,17 @@ public class NettyWebServerFactoryCustomizer return this.serverProperties.getForwardHeadersStrategy().equals(ServerProperties.ForwardHeadersStrategy.NATIVE); } - private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, Integer maxHttpHeaderSize) { - factory.addServerCustomizers((NettyServerCustomizer) (httpServer) -> httpServer.httpRequestDecoder( - (httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize(maxHttpHeaderSize))); + private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, DataSize maxHttpHeaderSize) { + factory.addServerCustomizers((httpServer) -> httpServer.httpRequestDecoder( + (httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize((int) maxHttpHeaderSize.toBytes()))); } - private NettyServerCustomizer getConnectionTimeOutCustomizer(int duration) { - return (httpServer) -> httpServer.tcpConfiguration( - (tcpServer) -> tcpServer.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, duration)); + private void customizeConnectionTimeout(NettyReactiveWebServerFactory factory, Duration connectionTimeout) { + if (!connectionTimeout.isZero()) { + long timeoutMillis = connectionTimeout.isNegative() ? 0 : connectionTimeout.toMillis(); + factory.addServerCustomizers((httpServer) -> httpServer.tcpConfiguration((tcpServer) -> tcpServer + .selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) timeoutMillis))); + } } } diff --git a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizerTests.java b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizerTests.java index 3d16d3b74b5..12ec09ddd8c 100644 --- a/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizerTests.java +++ b/spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/web/embedded/NettyWebServerFactoryCustomizerTests.java @@ -16,21 +16,38 @@ package org.springframework.boot.autoconfigure.web.embedded; +import java.time.Duration; +import java.util.Map; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelOption; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.MockitoAnnotations; +import reactor.netty.http.server.HttpServer; +import reactor.netty.tcp.TcpServer; import org.springframework.boot.autoconfigure.web.ServerProperties; import org.springframework.boot.context.properties.source.ConfigurationPropertySources; import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory; +import org.springframework.boot.web.embedded.netty.NettyServerCustomizer; import org.springframework.mock.env.MockEnvironment; +import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; /** * Tests for {@link NettyWebServerFactoryCustomizer}. * * @author Brian Clozel + * @author Artsiom Yudovin */ class NettyWebServerFactoryCustomizerTests { @@ -40,8 +57,12 @@ class NettyWebServerFactoryCustomizerTests { private NettyWebServerFactoryCustomizer customizer; + @Captor + private ArgumentCaptor customizerCaptor; + @BeforeEach public void setup() { + MockitoAnnotations.initMocks(this); this.environment = new MockEnvironment(); this.serverProperties = new ServerProperties(); ConfigurationPropertySources.attach(this.environment); @@ -71,4 +92,49 @@ class NettyWebServerFactoryCustomizerTests { verify(factory).setUseForwardHeaders(true); } + @Test + void setConnectionTimeoutAsZero() { + setupConnectionTimeout(Duration.ZERO); + NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class); + this.customizer.customize(factory); + verifyConnectionTimeout(factory, null); + } + + @Test + void setConnectionTimeoutAsMinusOne() { + setupConnectionTimeout(Duration.ofNanos(-1)); + NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class); + this.customizer.customize(factory); + verifyConnectionTimeout(factory, 0); + } + + @Test + void setConnectionTimeout() { + setupConnectionTimeout(Duration.ofSeconds(1)); + NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class); + this.customizer.customize(factory); + verifyConnectionTimeout(factory, 1000); + } + + @SuppressWarnings("unchecked") + private void verifyConnectionTimeout(NettyReactiveWebServerFactory factory, Integer expected) { + if (expected == null) { + verify(factory, never()).addServerCustomizers(any(NettyServerCustomizer.class)); + return; + } + verify(factory, times(1)).addServerCustomizers(this.customizerCaptor.capture()); + NettyServerCustomizer serverCustomizer = this.customizerCaptor.getValue(); + HttpServer httpServer = serverCustomizer.apply(HttpServer.create()); + TcpServer tcpConfiguration = ReflectionTestUtils.invokeMethod(httpServer, "tcpConfiguration"); + ServerBootstrap bootstrap = tcpConfiguration.configure(); + Map options = (Map) ReflectionTestUtils.getField(bootstrap, "options"); + assertThat(options).containsEntry(ChannelOption.CONNECT_TIMEOUT_MILLIS, expected); + } + + private void setupConnectionTimeout(Duration connectionTimeout) { + this.serverProperties.setUseForwardHeaders(null); + this.serverProperties.setMaxHttpHeaderSize(null); + this.serverProperties.setConnectionTimeout(connectionTimeout); + } + } diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/context/properties/PropertyMapper.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/context/properties/PropertyMapper.java index f5486abb291..72e8cd50ecd 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/context/properties/PropertyMapper.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/context/properties/PropertyMapper.java @@ -50,6 +50,7 @@ import org.springframework.util.StringUtils; * {@link Source#toInstance(Function) new instance}. * * @author Phillip Webb + * @author Artsiom Yudovin * @since 2.0.0 */ public final class PropertyMapper { @@ -288,7 +289,7 @@ public final class PropertyMapper { */ public Source whenNot(Predicate predicate) { Assert.notNull(predicate, "Predicate must not be null"); - return new Source<>(this.supplier, predicate.negate()); + return when(predicate.negate()); } /** @@ -299,7 +300,7 @@ public final class PropertyMapper { */ public Source when(Predicate predicate) { Assert.notNull(predicate, "Predicate must not be null"); - return new Source<>(this.supplier, predicate); + return new Source<>(this.supplier, (this.predicate != null) ? this.predicate.and(predicate) : predicate); } /** diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/PropertyMapperTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/PropertyMapperTests.java index 61be2605373..2ef680e29f0 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/PropertyMapperTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/context/properties/PropertyMapperTests.java @@ -28,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * Tests for {@link PropertyMapper}. * * @author Phillip Webb + * @author Artsiom Yudovin */ class PropertyMapperTests { @@ -195,6 +196,17 @@ class PropertyMapperTests { this.map.alwaysApplyingWhenNonNull().from(() -> null).toCall(Assertions::fail); } + @Test + public void whenWhenValueNotMatchesShouldSupportChainedCalls() { + this.map.from("123").when("456"::equals).when("123"::equals).toCall(Assertions::fail); + } + + @Test + public void whenWhenValueMatchesShouldSupportChainedCalls() { + String result = this.map.from("123").when((s) -> s.contains("2")).when("123"::equals).toInstance(String::new); + assertThat(result).isEqualTo("123"); + } + private static class Count implements Supplier { private final Supplier source;