diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index 88009d80590..7074a3e2521 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -18,12 +18,13 @@ package org.springframework.http.server.reactive; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; @@ -45,34 +46,26 @@ import org.springframework.util.MultiValueMap; */ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { - private static final int STATE_NEW = 1; - - private static final int STATE_COMMITTING = 2; - - private static final int STATE_COMMITTED = 3; - - private final Log logger = LogFactory.getLog(getClass()); private final DataBufferFactory dataBufferFactory; + private HttpStatus statusCode; + private final HttpHeaders headers; private final MultiValueMap cookies; - private final List>> beforeCommitActions = new ArrayList<>(4); + private volatile boolean committed; - private final AtomicInteger state = new AtomicInteger(STATE_NEW); - - private HttpStatus statusCode; + private final List>> commitActions = new ArrayList<>(4); public AbstractServerHttpResponse(DataBufferFactory dataBufferFactory) { Assert.notNull(dataBufferFactory, "'dataBufferFactory' must not be null"); - this.dataBufferFactory = dataBufferFactory; this.headers = new HttpHeaders(); - this.cookies = new LinkedMultiValueMap(); + this.cookies = new LinkedMultiValueMap<>(); } @@ -84,15 +77,17 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { @Override public boolean setStatusCode(HttpStatus statusCode) { Assert.notNull(statusCode); - if (STATE_NEW == this.state.get()) { + if (this.committed) { + if (logger.isDebugEnabled()) { + logger.debug("Can't set the status " + statusCode.toString() + + " because the HTTP response has already been committed"); + } + return false; + } + else { this.statusCode = statusCode; return true; } - else if (logger.isDebugEnabled()) { - logger.debug("Can't set the status " + statusCode.toString() + - " because the HTTP response has already been committed"); - } - return false; } @Override @@ -102,64 +97,78 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { @Override public HttpHeaders getHeaders() { - if (STATE_COMMITTED == this.state.get()) { - return HttpHeaders.readOnlyHttpHeaders(this.headers); - } - else { - return this.headers; - } + return (this.committed ? HttpHeaders.readOnlyHttpHeaders(this.headers) : this.headers); } @Override public MultiValueMap getCookies() { - if (STATE_COMMITTED == this.state.get()) { - return CollectionUtils.unmodifiableMultiValueMap(this.cookies); - } - return this.cookies; + return (this.committed ? CollectionUtils.unmodifiableMultiValueMap(this.cookies) : this.cookies); } @Override public void beforeCommit(Supplier> action) { - Assert.notNull(action); - this.beforeCommitActions.add(action); + if (action != null) { + this.commitActions.add(action); + } } @Override public final Mono writeWith(Publisher body) { - return new ChannelSendOperator<>(body, writePublisher -> applyBeforeCommit() - .then(() -> writeWithInternal(writePublisher))); + return new ChannelSendOperator<>(body, + writePublisher -> doCommit(() -> writeWithInternal(writePublisher))); } @Override public final Mono writeAndFlushWith(Publisher> body) { - return new ChannelSendOperator<>(body, writePublisher -> applyBeforeCommit() - .then(() -> writeAndFlushWithInternal(writePublisher))); + return new ChannelSendOperator<>(body, + writePublisher -> doCommit(() -> writeAndFlushWithInternal(writePublisher))); } @Override public Mono setComplete() { - return applyBeforeCommit(); + return doCommit(); + } + + /** + * A variant of {@link #doCommit(Supplier)} for a response without no body. + * @return a completion publisher + */ + protected Mono doCommit() { + return doCommit(null); } - protected Mono applyBeforeCommit() { - Mono mono = Mono.empty(); - if (this.state.compareAndSet(STATE_NEW, STATE_COMMITTING)) { - for (Supplier> action : this.beforeCommitActions) { - mono = mono.then(action); + /** + * Apply {@link #beforeCommit(Supplier) beforeCommit} actions, apply the + * response status and headers/cookies, and write the response body. + * @param writeAction the action to write the response body or {@code null} + * @return a completion publisher + */ + protected Mono doCommit(Supplier> writeAction) { + if (this.committed) { + if (logger.isDebugEnabled()) { + logger.debug("Can't set the status " + statusCode.toString() + + " because the HTTP response has already been committed"); } - mono = mono.otherwise(ex -> { - // Ignore errors from beforeCommit actions - return Mono.empty(); - }); - mono = mono.then(() -> { - this.state.set(STATE_COMMITTED); - writeStatusCode(); - writeHeaders(); - writeCookies(); - return Mono.empty(); - }); + return Mono.empty(); } - return mono; + + this.committed = true; + + this.commitActions.add(() -> { + applyStatusCode(); + applyHeaders(); + applyCookies(); + return Mono.empty(); + }); + + if (writeAction != null) { + this.commitActions.add(writeAction); + } + + List> actions = this.commitActions.stream() + .map(Supplier::get).collect(Collectors.toList()); + + return Flux.concat(actions).next(); } @@ -180,18 +189,18 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { * Implement this method to write the status code to the underlying response. * This method is called once only. */ - protected abstract void writeStatusCode(); + protected abstract void applyStatusCode(); /** * Implement this method to apply header changes from {@link #getHeaders()} * to the underlying response. This method is called once only. */ - protected abstract void writeHeaders(); + protected abstract void applyHeaders(); /** * Implement this method to add cookies from {@link #getHeaders()} to the * underlying response. This method is called once only. */ - protected abstract void writeCookies(); + protected abstract void applyCookies(); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java index dbbe17eaea6..45dfc7f1e65 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ReactorServerHttpResponse.java @@ -62,7 +62,7 @@ public class ReactorServerHttpResponse extends AbstractServerHttpResponse @Override - protected void writeStatusCode() { + protected void applyStatusCode() { HttpStatus statusCode = this.getStatusCode(); if (statusCode != null) { getReactorChannel().status(HttpResponseStatus.valueOf(statusCode.value())); @@ -83,7 +83,7 @@ public class ReactorServerHttpResponse extends AbstractServerHttpResponse } @Override - protected void writeHeaders() { + protected void applyHeaders() { // TODO: temporarily, see https://github.com/reactor/reactor-netty/issues/2 if(getHeaders().containsKey(HttpHeaders.CONTENT_LENGTH)){ this.channel.responseTransfer(false); @@ -96,7 +96,7 @@ public class ReactorServerHttpResponse extends AbstractServerHttpResponse } @Override - protected void writeCookies() { + protected void applyCookies() { for (String name : getCookies().keySet()) { for (ResponseCookie httpCookie : getCookies().get(name)) { Cookie cookie = new DefaultCookie(name, httpCookie.getValue()); @@ -114,12 +114,11 @@ public class ReactorServerHttpResponse extends AbstractServerHttpResponse @Override public Mono writeWith(File file, long position, long count) { - return applyBeforeCommit().then(() -> this.channel.sendFile(file, position, count)); + return doCommit(() -> this.channel.sendFile(file, position, count)); } private static Publisher toByteBufs(Publisher dataBuffers) { - return Flux.from(dataBuffers). - map(NettyDataBufferFactory::toByteBuf); + return Flux.from(dataBuffers).map(NettyDataBufferFactory::toByteBuf); } diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java index 6ce89768a6f..e0c04346ee9 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/RxNettyServerHttpResponse.java @@ -63,7 +63,7 @@ public class RxNettyServerHttpResponse extends AbstractServerHttpResponse { @Override - protected void writeStatusCode() { + protected void applyStatusCode() { HttpStatus statusCode = this.getStatusCode(); if (statusCode != null) { this.response.setStatus(HttpResponseStatus.valueOf(statusCode.value())); @@ -91,7 +91,7 @@ public class RxNettyServerHttpResponse extends AbstractServerHttpResponse { } @Override - protected void writeHeaders() { + protected void applyHeaders() { for (String name : getHeaders().keySet()) { for (String value : getHeaders().get(name)) { this.response.addHeader(name, value); @@ -100,7 +100,7 @@ public class RxNettyServerHttpResponse extends AbstractServerHttpResponse { } @Override - protected void writeCookies() { + protected void applyCookies() { for (String name : getCookies().keySet()) { for (ResponseCookie httpCookie : getCookies().get(name)) { Cookie cookie = new DefaultCookie(name, httpCookie.getValue()); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java index 01b7b55ba08..1755d2ab1a9 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpResponse.java @@ -72,7 +72,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } @Override - protected void writeStatusCode() { + protected void applyStatusCode() { HttpStatus statusCode = this.getStatusCode(); if (statusCode != null) { getServletResponse().setStatus(statusCode.value()); @@ -80,7 +80,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } @Override - protected void writeHeaders() { + protected void applyHeaders() { for (Map.Entry> entry : getHeaders().entrySet()) { String headerName = entry.getKey(); for (String headerValue : entry.getValue()) { @@ -98,7 +98,7 @@ public class ServletServerHttpResponse extends AbstractListenerServerHttpRespons } @Override - protected void writeCookies() { + protected void applyCookies() { for (String name : getCookies().keySet()) { for (ResponseCookie httpCookie : getCookies().get(name)) { Cookie cookie = new Cookie(name, httpCookie.getValue()); diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java index 30102b6e2d5..b07b8cdad1e 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/UndertowServerHttpResponse.java @@ -68,7 +68,7 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon } @Override - protected void writeStatusCode() { + protected void applyStatusCode() { HttpStatus statusCode = this.getStatusCode(); if (statusCode != null) { getUndertowExchange().setStatusCode(statusCode.value()); @@ -77,8 +77,8 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon @Override public Mono writeWith(File file, long position, long count) { - writeHeaders(); - writeCookies(); + applyHeaders(); + applyCookies(); try { StreamSinkChannel responseChannel = getUndertowExchange().getResponseChannel(); @SuppressWarnings("resource") @@ -98,7 +98,7 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon } @Override - protected void writeHeaders() { + protected void applyHeaders() { for (Map.Entry> entry : getHeaders().entrySet()) { HttpString headerName = HttpString.tryFromString(entry.getKey()); this.exchange.getResponseHeaders().addAll(headerName, entry.getValue()); @@ -106,7 +106,7 @@ public class UndertowServerHttpResponse extends AbstractListenerServerHttpRespon } @Override - protected void writeCookies() { + protected void applyCookies() { for (String name : getCookies().keySet()) { for (ResponseCookie httpCookie : getCookies().get(name)) { Cookie cookie = new CookieImpl(name, httpCookie.getValue()); diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java index e143b67e512..34b8b03b237 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java @@ -158,19 +158,19 @@ public class ServerHttpResponseTests { } @Override - public void writeStatusCode() { + public void applyStatusCode() { assertFalse(this.statusCodeWritten); this.statusCodeWritten = true; } @Override - protected void writeHeaders() { + protected void applyHeaders() { assertFalse(this.headersWritten); this.headersWritten = true; } @Override - protected void writeCookies() { + protected void applyCookies() { assertFalse(this.cookiesWritten); this.cookiesWritten = true; }