diff --git a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java index c3b95454873..c6ed003a4b6 100644 --- a/spring-web/src/main/java/org/springframework/http/HttpHeaders.java +++ b/spring-web/src/main/java/org/springframework/http/HttpHeaders.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -158,8 +158,8 @@ public class HttpHeaders implements MultiValueMap, Serializable List result = (value != null) ? MediaType.parseMediaTypes(value) : Collections.emptyList(); // Some containers parse 'Accept' into multiple values - if ((result.size() == 1) && (headers.get(ACCEPT).size() > 1)) { - value = StringUtils.collectionToCommaDelimitedString(headers.get(ACCEPT)); + if ((result.size() == 1) && (get(ACCEPT).size() > 1)) { + value = StringUtils.collectionToCommaDelimitedString(get(ACCEPT)); result = MediaType.parseMediaTypes(value); } diff --git a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java index b65d0cdcada..802a53dfde9 100644 --- a/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/ServletServerHttpResponse.java @@ -18,6 +18,8 @@ package org.springframework.http.server; import java.io.IOException; import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -26,18 +28,25 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; /** * {@link ServerHttpResponse} implementation that is based on a {@link HttpServletResponse}. * * @author Arjen Poutsma + * @author Rossen Stoyanchev * @since 3.0 */ public class ServletServerHttpResponse implements ServerHttpResponse { + private static final boolean servlet3Present = + ClassUtils.isPresent("javax.servlet.AsyncContext", ServletServerHttpResponse.class.getClassLoader()); + + private final HttpServletResponse servletResponse; - private final HttpHeaders headers = new HttpHeaders(); + private final HttpHeaders headers; private boolean headersWritten = false; @@ -49,6 +58,7 @@ public class ServletServerHttpResponse implements ServerHttpResponse { public ServletServerHttpResponse(HttpServletResponse servletResponse) { Assert.notNull(servletResponse, "'servletResponse' must not be null"); this.servletResponse = servletResponse; + this.headers = (servlet3Present ? new ServletResponseHttpHeaders() : new HttpHeaders()); } @@ -105,4 +115,56 @@ public class ServletServerHttpResponse implements ServerHttpResponse { this.headersWritten = true; } } + + /** + * Extends HttpHeaders with the ability to look up headers already present in + * the underlying HttpServletResponse. + * + * The intent is merely to expose what is available through the HttpServletResponse + * i.e. the ability to look up specific header values by name. All other + * map-related operations (e.g. iteration, removal, etc) apply only to values + * added directly through HttpHeaders methods. + * + * @since 4.0.3 + */ + private class ServletResponseHttpHeaders extends HttpHeaders { + + private static final long serialVersionUID = 3410708522401046302L; + + @Override + public String getFirst(String headerName) { + String value = servletResponse.getHeader(headerName); + if (value != null) { + return value; + } + else { + return super.getFirst(headerName); + } + } + + @Override + public List get(Object key) { + + Assert.isInstanceOf(String.class, key, "key must be a String-based header name"); + Collection values1 = servletResponse.getHeaders((String) key); + boolean isEmpty1 = CollectionUtils.isEmpty(values1); + + List values2 = super.get(key); + boolean isEmpty2 = CollectionUtils.isEmpty(values2); + + if (isEmpty1 && isEmpty2) { + return null; + } + + List values = new ArrayList(); + if (!isEmpty1) { + values.addAll(values1); + } + if (!isEmpty2) { + values.addAll(values2); + } + return values; + } + } + } diff --git a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java index 0c329bd785c..ddcc376f585 100644 --- a/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/ServletServerHttpResponseTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2012 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.http.server; import java.nio.charset.Charset; +import java.util.Arrays; import java.util.List; import org.junit.Before; @@ -71,6 +72,19 @@ public class ServletServerHttpResponseTests { assertEquals("Invalid Content-Type", "UTF-8", mockResponse.getCharacterEncoding()); } + @Test + public void getHeadersFromHttpServletResponse() { + + String headerName = "Access-Control-Allow-Origin"; + String headerValue = "localhost:8080"; + + this.mockResponse.addHeader(headerName, headerValue); + this.response = new ServletServerHttpResponse(this.mockResponse); + + assertEquals(headerValue, this.response.getHeaders().getFirst(headerName)); + assertEquals(Arrays.asList(headerValue), this.response.getHeaders().get(headerName)); + } + @Test public void getBody() throws Exception { byte[] content = "Hello World".getBytes("UTF-8"); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java index cdf196c59cd..1baf7e88f4d 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import java.util.concurrent.TimeUnit; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.InvalidMediaTypeException; @@ -352,22 +353,32 @@ public abstract class AbstractSockJsService implements SockJsService { protected void addCorsHeaders(ServerHttpRequest request, ServerHttpResponse response, HttpMethod... httpMethods) { - String origin = request.getHeaders().getFirst("origin"); + + HttpHeaders requestHeaders = request.getHeaders(); + HttpHeaders responseHeaders = response.getHeaders(); + + // Perhaps a CORS Filter has already added this? + if (!CollectionUtils.isEmpty(responseHeaders.get("Access-Control-Allow-Origin"))) { + logger.debug("Skip adding CORS headers, response already contains \"Access-Control-Allow-Origin\""); + return; + } + + String origin = requestHeaders.getFirst("origin"); origin = ((origin == null) || origin.equals("null")) ? "*" : origin; - response.getHeaders().add("Access-Control-Allow-Origin", origin); - response.getHeaders().add("Access-Control-Allow-Credentials", "true"); + responseHeaders.add("Access-Control-Allow-Origin", origin); + responseHeaders.add("Access-Control-Allow-Credentials", "true"); - List accessControllerHeaders = request.getHeaders().get("Access-Control-Request-Headers"); + List accessControllerHeaders = requestHeaders.get("Access-Control-Request-Headers"); if (accessControllerHeaders != null) { for (String header : accessControllerHeaders) { - response.getHeaders().add("Access-Control-Allow-Headers", header); + responseHeaders.add("Access-Control-Allow-Headers", header); } } if (!ObjectUtils.isEmpty(httpMethods)) { - response.getHeaders().add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", ")); - response.getHeaders().add("Access-Control-Max-Age", String.valueOf(ONE_YEAR)); + responseHeaders.add("Access-Control-Allow-Methods", StringUtils.arrayToDelimitedString(httpMethods, ", ")); + responseHeaders.add("Access-Control-Max-Age", String.valueOf(ONE_YEAR)); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java index aa6f95b5fe1..20c3eb19c3d 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/support/SockJsServiceTests.java @@ -54,25 +54,25 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { public void validateRequest() throws Exception { this.service.setWebSocketEnabled(false); - handleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.NOT_FOUND); this.service.setWebSocketEnabled(true); - handleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK); - - handleRequest("GET", "/echo//", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo///", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND); - handleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/server/session/websocket", HttpStatus.OK); + + resetResponseAndHandleRequest("GET", "/echo//", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo///", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/other", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo//service/websocket", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/server//websocket", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/server/session/", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/s.erver/session/websocket", HttpStatus.NOT_FOUND); + resetResponseAndHandleRequest("GET", "/echo/server/s.ession/websocket", HttpStatus.NOT_FOUND); } @Test public void handleInfoGet() throws Exception { - handleRequest("GET", "/echo/info", HttpStatus.OK); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); assertEquals("application/json;charset=UTF-8", this.servletResponse.getContentType()); assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin")); @@ -86,19 +86,32 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { this.service.setSessionCookieNeeded(false); this.service.setWebSocketEnabled(false); - handleRequest("GET", "/echo/info", HttpStatus.OK); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); body = this.servletResponse.getContentAsString(); assertEquals(",\"origins\":[\"*:*\"],\"cookie_needed\":false,\"websocket\":false}", body.substring(body.indexOf(','))); } + // SPR-11443 + + @Test + public void handleInfoGetCorsFilter() throws Exception { + + // Simulate scenario where Filter would have already set CORS headers + this.servletResponse.setHeader("Access-Control-Allow-Origin", "foobar:123"); + + handleRequest("GET", "/echo/info", HttpStatus.OK); + + assertEquals("foobar:123", this.servletResponse.getHeader("Access-Control-Allow-Origin")); + } + @Test public void handleInfoOptions() throws Exception { this.servletRequest.addHeader("Access-Control-Request-Headers", "Last-Modified"); - handleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); + resetResponseAndHandleRequest("OPTIONS", "/echo/info", HttpStatus.NO_CONTENT); this.response.flush(); assertEquals("*", this.servletResponse.getHeader("Access-Control-Allow-Origin")); @@ -111,7 +124,7 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { @Test public void handleIframeRequest() throws Exception { - handleRequest("GET", "/echo/iframe.html", HttpStatus.OK); + resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.OK); assertEquals("text/html;charset=UTF-8", this.servletResponse.getContentType()); assertTrue(this.servletResponse.getContentAsString().startsWith("\n")); @@ -125,23 +138,35 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { this.servletRequest.addHeader("If-None-Match", "\"0da1ed070012f304e47b83c81c48ad620\""); - handleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED); + resetResponseAndHandleRequest("GET", "/echo/iframe.html", HttpStatus.NOT_MODIFIED); } @Test public void handleRawWebSocketRequest() throws Exception { - handleRequest("GET", "/echo", HttpStatus.OK); + resetResponseAndHandleRequest("GET", "/echo", HttpStatus.OK); assertEquals("Welcome to SockJS!\n", this.servletResponse.getContentAsString()); - handleRequest("GET", "/echo/websocket", HttpStatus.OK); + resetResponseAndHandleRequest("GET", "/echo/websocket", HttpStatus.OK); assertNull("Raw WebSocket should not open a SockJS session", this.service.sessionId); assertSame(this.handler, this.service.handler); } + @Test + public void handleEmptyContentType() throws Exception { - private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException { + servletRequest.setContentType(""); + resetResponseAndHandleRequest("GET", "/echo/info", HttpStatus.OK); + + assertEquals("Invalid/empty content should have been ignored", 200, this.servletResponse.getStatus()); + } + + private void resetResponseAndHandleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException { resetResponse(); + handleRequest(httpMethod, uri, httpStatus); + } + + private void handleRequest(String httpMethod, String uri, HttpStatus httpStatus) throws IOException { setRequest(httpMethod, uri); String sockJsPath = uri.substring("/echo".length()); this.service.handleRequest(this.request, this.response, sockJsPath, this.handler); @@ -149,15 +174,6 @@ public class SockJsServiceTests extends AbstractHttpRequestTests { assertEquals(httpStatus.value(), this.servletResponse.getStatus()); } - @Test - public void handleEmptyContentType() throws Exception { - - servletRequest.setContentType(""); - handleRequest("GET", "/echo/info", HttpStatus.OK); - - assertEquals("Invalid/empty content should have been ignored", 200, this.servletResponse.getStatus()); - } - private static class TestSockJsService extends AbstractSockJsService {