diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AbstractRouterFunctionIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AbstractRouterFunctionIntegrationTests.java index b26ea2ed9eb..9a21fdcf381 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AbstractRouterFunctionIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/AbstractRouterFunctionIntegrationTests.java @@ -27,9 +27,13 @@ public abstract class AbstractRouterFunctionIntegrationTests extends AbstractHtt @Override protected final HttpHandler createHttpHandler() { RouterFunction routerFunction = routerFunction(); - return RouterFunctions.toHttpHandler(routerFunction); + return RouterFunctions.toHttpHandler(routerFunction, handlerStrategies()); } protected abstract RouterFunction routerFunction(); + protected HandlerStrategies handlerStrategies() { + return HandlerStrategies.withDefaults(); + } + } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RenderingResponseIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RenderingResponseIntegrationTests.java new file mode 100644 index 00000000000..8f6dc7255c3 --- /dev/null +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RenderingResponseIntegrationTests.java @@ -0,0 +1,166 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.reactive.function.server; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import org.junit.Test; +import reactor.core.publisher.Mono; + +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.http.server.reactive.ServerHttpResponse; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.result.view.View; +import org.springframework.web.reactive.result.view.ViewResolver; +import org.springframework.web.server.ServerWebExchange; + +import static org.junit.Assert.*; +import static org.springframework.web.reactive.function.server.HandlerFilterFunction.ofResponseProcessor; +import static org.springframework.web.reactive.function.server.RequestPredicates.GET; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +/** + * @author Arjen Poutsma + * @since 5.0 + */ +public class RenderingResponseIntegrationTests extends AbstractRouterFunctionIntegrationTests { + + private final RestTemplate restTemplate = new RestTemplate(); + + + @Override + protected RouterFunction routerFunction() { + RenderingResponseHandler handler = new RenderingResponseHandler(); + RouterFunction normalRoute = route(GET("/normal"), handler::render); + RouterFunction filteredRoute = route(GET("/filter"), handler::render) + .filter(ofResponseProcessor( + response -> { + Map model = new LinkedHashMap<>(response.model()); + model.put("qux", "quux"); + + return RenderingResponse.create(response.name()) + .modelAttributes(model) + .build(); + })); + + return normalRoute.and(filteredRoute); + } + + @Override + protected HandlerStrategies handlerStrategies() { + return HandlerStrategies.builder() + .viewResolver(new DummyViewResolver()) + .build(); + + } + + @Test + public void normal() throws Exception { + ResponseEntity result = + restTemplate.getForEntity("http://localhost:" + port + "/normal", String.class); + + assertEquals(HttpStatus.OK, result.getStatusCode()); + Map body = parseBody(result.getBody()); + assertEquals(2, body.size()); + assertEquals("foo", body.get("name")); + assertEquals("baz", body.get("bar")); + } + + @Test + public void filter() throws Exception { + ResponseEntity result = + restTemplate.getForEntity("http://localhost:" + port + "/filter", String.class); + + assertEquals(HttpStatus.OK, result.getStatusCode()); + Map body = parseBody(result.getBody()); + assertEquals(3, body.size()); + assertEquals("foo", body.get("name")); + assertEquals("baz", body.get("bar")); + assertEquals("quux", body.get("qux")); + } + + private Map parseBody(String body) { + String[] lines = body.split("\\n"); + Map result = new LinkedHashMap<>(lines.length); + for (String line : lines) { + int idx = line.indexOf('='); + String key = line.substring(0, idx); + String value = line.substring(idx + 1); + result.put(key, value); + } + return result; + } + + private static class RenderingResponseHandler { + + public Mono render(ServerRequest request) { + return RenderingResponse.create("foo") + .modelAttribute("bar", "baz") + .build(); + } + + } + + private static class DummyViewResolver implements ViewResolver { + + @Override + public Mono resolveViewName(String viewName, Locale locale) { + return Mono.just(new DummyView(viewName)); + } + } + + + private static class DummyView implements View { + + private final String name; + + public DummyView(String name) { + this.name = name; + } + + @Override + public List getSupportedMediaTypes() { + return Collections.emptyList(); + } + + @Override + public Mono render(Map model, MediaType contentType, + ServerWebExchange exchange) { + StringBuilder builder = new StringBuilder(); + builder.append("name=").append(this.name).append('\n'); + for (Map.Entry entry : model.entrySet()) { + builder.append(entry.getKey()).append('=').append(entry.getValue()).append('\n'); + } + builder.setLength(builder.length() - 1); + byte[] bytes = builder.toString().getBytes(StandardCharsets.UTF_8); + + ServerHttpResponse response = exchange.getResponse(); + DataBuffer buffer = response.bufferFactory().wrap(bytes); + response.getHeaders().setContentType(MediaType.TEXT_PLAIN); + return response.writeWith(Mono.just(buffer)); + } + } + +}