diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java
new file mode 100644
index 00000000000..eb0bf1534e3
--- /dev/null
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistration.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.reactive.config;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.springframework.http.HttpMethod;
+import org.springframework.web.bind.annotation.CrossOrigin;
+import org.springframework.web.cors.CorsConfiguration;
+
+/**
+ * {@code CorsRegistration} assists with the creation of a
+ * {@link CorsConfiguration} instance mapped to a path pattern.
+ *
+ *
If no path pattern is specified, cross-origin request handling is
+ * mapped to {@code "/**"}.
+ *
+ *
By default, all origins, all headers, credentials and {@code GET},
+ * {@code HEAD}, and {@code POST} methods are allowed, and the max age is
+ * set to 30 minutes.
+ *
+ * @author Sebastien Deleuze
+ * @author Sam Brannen
+ * @since 5.0
+ * @see CorsConfiguration
+ * @see CorsRegistry
+ */
+public class CorsRegistration {
+
+ private final String pathPattern;
+
+ private final CorsConfiguration config;
+
+
+ public CorsRegistration(String pathPattern) {
+ this.pathPattern = pathPattern;
+ // Same implicit default values as the @CrossOrigin annotation + allows simple methods
+ this.config = new CorsConfiguration();
+ this.config.setAllowedOrigins(Arrays.asList(CrossOrigin.DEFAULT_ORIGINS));
+ this.config.setAllowedMethods(Arrays.asList(HttpMethod.GET.name(),
+ HttpMethod.HEAD.name(), HttpMethod.POST.name()));
+ this.config.setAllowedHeaders(Arrays.asList(CrossOrigin.DEFAULT_ALLOWED_HEADERS));
+ this.config.setAllowCredentials(CrossOrigin.DEFAULT_ALLOW_CREDENTIALS);
+ this.config.setMaxAge(CrossOrigin.DEFAULT_MAX_AGE);
+ }
+
+
+ public CorsRegistration allowedOrigins(String... origins) {
+ this.config.setAllowedOrigins(new ArrayList<>(Arrays.asList(origins)));
+ return this;
+ }
+
+ public CorsRegistration allowedMethods(String... methods) {
+ this.config.setAllowedMethods(new ArrayList<>(Arrays.asList(methods)));
+ return this;
+ }
+
+ public CorsRegistration allowedHeaders(String... headers) {
+ this.config.setAllowedHeaders(new ArrayList<>(Arrays.asList(headers)));
+ return this;
+ }
+
+ public CorsRegistration exposedHeaders(String... headers) {
+ this.config.setExposedHeaders(new ArrayList<>(Arrays.asList(headers)));
+ return this;
+ }
+
+ public CorsRegistration maxAge(long maxAge) {
+ this.config.setMaxAge(maxAge);
+ return this;
+ }
+
+ public CorsRegistration allowCredentials(boolean allowCredentials) {
+ this.config.setAllowCredentials(allowCredentials);
+ return this;
+ }
+
+ protected String getPathPattern() {
+ return this.pathPattern;
+ }
+
+ protected CorsConfiguration getCorsConfiguration() {
+ return this.config;
+ }
+
+}
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java
new file mode 100644
index 00000000000..f7f11e7a1dc
--- /dev/null
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/CorsRegistry.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.reactive.config;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.web.cors.CorsConfiguration;
+
+/**
+ * {@code CorsRegistry} assists with the registration of {@link CorsConfiguration}
+ * mapped to a path pattern.
+ *
+ * @author Sebastien Deleuze
+ * @since 5.0
+ */
+public class CorsRegistry {
+
+ private final List registrations = new ArrayList<>();
+
+
+ /**
+ * Enable cross origin request handling for the specified path pattern.
+ *
+ * Exact path mapping URIs (such as {@code "/admin"}) are supported as
+ * well as Ant-style path patterns (such as {@code "/admin/**"}).
+ *
+ *
By default, all origins, all headers, credentials and {@code GET},
+ * {@code HEAD}, and {@code POST} methods are allowed, and the max age
+ * is set to 30 minutes.
+ */
+ public CorsRegistration addMapping(String pathPattern) {
+ CorsRegistration registration = new CorsRegistration(pathPattern);
+ this.registrations.add(registration);
+ return registration;
+ }
+
+ protected Map getCorsConfigurations() {
+ Map configs = new LinkedHashMap<>(this.registrations.size());
+ for (CorsRegistration registration : this.registrations) {
+ configs.put(registration.getPathPattern(), registration.getCorsConfiguration());
+ }
+ return configs;
+ }
+}
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java
index b73398fd00d..2a92ab463a5 100644
--- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/config/WebReactiveConfiguration.java
@@ -54,6 +54,7 @@ import org.springframework.http.codec.xml.Jaxb2XmlEncoder;
import org.springframework.util.ClassUtils;
import org.springframework.validation.Errors;
import org.springframework.validation.Validator;
+import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.reactive.accept.CompositeContentTypeResolver;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
@@ -95,6 +96,8 @@ public class WebReactiveConfiguration implements ApplicationContextAware {
private List> messageWriters;
+ private Map corsConfigurations;
+
private ApplicationContext applicationContext;
@@ -113,6 +116,7 @@ public class WebReactiveConfiguration implements ApplicationContextAware {
RequestMappingHandlerMapping mapping = createRequestMappingHandlerMapping();
mapping.setOrder(0);
mapping.setContentTypeResolver(mvcContentTypeResolver());
+ mapping.setCorsConfigurations(getCorsConfigurations());
PathMatchConfigurer configurer = getPathMatchConfigurer();
if (configurer.isUseSuffixPatternMatch() != null) {
@@ -440,6 +444,22 @@ public class WebReactiveConfiguration implements ApplicationContextAware {
protected void configureViewResolvers(ViewResolverRegistry registry) {
}
+ protected final Map getCorsConfigurations() {
+ if (this.corsConfigurations == null) {
+ CorsRegistry registry = new CorsRegistry();
+ addCorsMappings(registry);
+ this.corsConfigurations = registry.getCorsConfigurations();
+ }
+ return this.corsConfigurations;
+ }
+
+ /**
+ * Override this method to configure cross origin requests processing.
+ * @see CorsRegistry
+ */
+ protected void addCorsMappings(CorsRegistry registry) {
+ }
+
private static final class EmptyHandlerMapping extends AbstractHandlerMapping {
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java
index 7dac4edc954..b009b23cc73 100644
--- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractHandlerMapping.java
@@ -15,12 +15,24 @@
*/
package org.springframework.web.reactive.handler;
+import java.util.Map;
+
+import reactor.core.publisher.Mono;
+
import org.springframework.context.support.ApplicationObjectSupport;
import org.springframework.core.Ordered;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.Assert;
import org.springframework.util.PathMatcher;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.cors.reactive.CorsConfigurationSource;
+import org.springframework.web.cors.reactive.CorsProcessor;
+import org.springframework.web.cors.reactive.CorsUtils;
+import org.springframework.web.cors.reactive.DefaultCorsProcessor;
+import org.springframework.web.cors.reactive.UrlBasedCorsConfigurationSource;
import org.springframework.web.reactive.HandlerMapping;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebHandler;
import org.springframework.web.util.HttpRequestPathHelper;
/**
@@ -39,8 +51,9 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
private PathMatcher pathMatcher = new AntPathMatcher();
+ protected CorsProcessor corsProcessor = new DefaultCorsProcessor();
- // TODO: CORS
+ protected final UrlBasedCorsConfigurationSource corsConfigSource = new UrlBasedCorsConfigurationSource();
/**
* Specify the order value for this HandlerMapping bean.
@@ -91,7 +104,7 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
public void setPathMatcher(PathMatcher pathMatcher) {
Assert.notNull(pathMatcher, "PathMatcher must not be null");
this.pathMatcher = pathMatcher;
- // this.corsConfigSource.setPathMatcher(pathMatcher);
+ this.corsConfigSource.setPathMatcher(pathMatcher);
}
/**
@@ -102,4 +115,62 @@ public abstract class AbstractHandlerMapping extends ApplicationObjectSupport
return this.pathMatcher;
}
+ /**
+ * Configure a custom {@link CorsProcessor} to use to apply the matched
+ * {@link CorsConfiguration} for a request. By default {@link DefaultCorsProcessor} is used.
+ */
+ public void setCorsProcessor(CorsProcessor corsProcessor) {
+ Assert.notNull(corsProcessor, "CorsProcessor must not be null");
+ this.corsProcessor = corsProcessor;
+ }
+
+ /**
+ * Return the configured {@link CorsProcessor}.
+ */
+ public CorsProcessor getCorsProcessor() {
+ return this.corsProcessor;
+ }
+
+ /**
+ * Set "global" CORS configuration based on URL patterns. By default the first
+ * matching URL pattern is combined with the CORS configuration for the
+ * handler, if any.
+ */
+ public void setCorsConfigurations(Map corsConfigurations) {
+ this.corsConfigSource.setCorsConfigurations(corsConfigurations);
+ }
+
+ /**
+ * Get the CORS configuration.
+ */
+ public Map getCorsConfigurations() {
+ return this.corsConfigSource.getCorsConfigurations();
+ }
+
+ protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) {
+ if (handler != null && handler instanceof CorsConfigurationSource) {
+ return ((CorsConfigurationSource) handler).getCorsConfiguration(exchange);
+ }
+ return null;
+ }
+
+ protected Object processCorsRequest(ServerWebExchange exchange, Object handler) {
+ if (CorsUtils.isCorsRequest(exchange.getRequest())) {
+ CorsConfiguration globalConfig = this.corsConfigSource.getCorsConfiguration(exchange);
+ CorsConfiguration handlerConfig = getCorsConfiguration(handler, exchange);
+ CorsConfiguration config = (globalConfig != null ? globalConfig.combine(handlerConfig) : handlerConfig);
+ if (!corsProcessor.processRequest(config, exchange) || CorsUtils.isPreFlightRequest(exchange.getRequest())) {
+ return new NoOpHandler();
+ }
+ }
+ return handler;
+ }
+
+ private class NoOpHandler implements WebHandler {
+ @Override
+ public Mono handle(ServerWebExchange exchange) {
+ return Mono.empty();
+ }
+ }
+
}
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java
index 1cc674bb606..bb87cee7078 100644
--- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/handler/AbstractUrlHandlerMapping.java
@@ -101,6 +101,7 @@ public abstract class AbstractUrlHandlerMapping extends AbstractHandlerMapping {
Object handler = null;
try {
handler = lookupHandler(lookupPath, exchange);
+ handler = processCorsRequest(exchange, handler);
}
catch (Exception ex) {
return Mono.error(ex);
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java
index e69f12f8f49..37fafc379a6 100644
--- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/condition/ConsumesRequestCondition.java
@@ -27,6 +27,7 @@ import java.util.Set;
import org.springframework.http.InvalidMediaTypeException;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.cors.reactive.CorsUtils;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.UnsupportedMediaTypeStatusException;
@@ -43,7 +44,7 @@ import org.springframework.web.server.UnsupportedMediaTypeStatusException;
*/
public final class ConsumesRequestCondition extends AbstractRequestCondition {
-// private final static ConsumesRequestCondition PRE_FLIGHT_MATCH = new ConsumesRequestCondition();
+ private final static ConsumesRequestCondition PRE_FLIGHT_MATCH = new ConsumesRequestCondition();
private final List expressions;
@@ -160,9 +161,9 @@ public final class ConsumesRequestCondition extends AbstractRequestCondition {
-// private final static HeadersRequestCondition PRE_FLIGHT_MATCH = new HeadersRequestCondition();
+ private final static HeadersRequestCondition PRE_FLIGHT_MATCH = new HeadersRequestCondition();
private final Set expressions;
@@ -107,9 +108,9 @@ public final class HeadersRequestCondition extends AbstractRequestCondition {
-// private final static ProducesRequestCondition PRE_FLIGHT_MATCH = new ProducesRequestCondition();
+ private final static ProducesRequestCondition PRE_FLIGHT_MATCH = new ProducesRequestCondition();
private final List MEDIA_TYPE_ALL_LIST =
@@ -182,9 +183,9 @@ public final class ProducesRequestCondition extends AbstractRequestCondition extends AbstractHandlerMap
*/
private static final String SCOPED_TARGET_NAME_PREFIX = "scopedTarget.";
+ private static final HandlerMethod PREFLIGHT_AMBIGUOUS_MATCH =
+ new HandlerMethod(new EmptyHandler(), ClassUtils.getMethod(EmptyHandler.class, "handle"));
+
+ private static final CorsConfiguration ALLOW_CORS_CONFIG = new CorsConfiguration();
+
+ static {
+ ALLOW_CORS_CONFIG.addAllowedOrigin("*");
+ ALLOW_CORS_CONFIG.addAllowedMethod("*");
+ ALLOW_CORS_CONFIG.addAllowedHeader("*");
+ ALLOW_CORS_CONFIG.setAllowCredentials(true);
+ }
+
private final MappingRegistry mappingRegistry = new MappingRegistry();
@@ -212,6 +230,13 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
return handlerMethod;
}
+ /**
+ * Extract and return the CORS configuration for the mapping.
+ */
+ protected CorsConfiguration initCorsConfiguration(Object handler, Method method, T mapping) {
+ return null;
+ }
+
/**
* Invoked after all handler methods have been detected.
* @param handlerMethods a read-only map with handler methods and mappings.
@@ -249,7 +274,10 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
logger.debug("Did not find handler method for [" + lookupPath + "]");
}
}
- return (handlerMethod != null ? Mono.just(handlerMethod.createWithResolvedBean()) : Mono.empty());
+ if (handlerMethod != null) {
+ handlerMethod = handlerMethod.createWithResolvedBean();
+ }
+ return Mono.justOrEmpty(processCorsRequest(exchange, handlerMethod));
}
finally {
this.mappingRegistry.releaseReadLock();
@@ -287,6 +315,9 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
}
Match bestMatch = matches.get(0);
if (matches.size() > 1) {
+ if (CorsUtils.isPreFlightRequest(exchange.getRequest())) {
+ return PREFLIGHT_AMBIGUOUS_MATCH;
+ }
Match secondBestMatch = matches.get(1);
if (comparator.compare(bestMatch, secondBestMatch) == 0) {
Method m1 = bestMatch.handlerMethod.getMethod();
@@ -335,6 +366,22 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
return null;
}
+ @Override
+ protected CorsConfiguration getCorsConfiguration(Object handler, ServerWebExchange exchange) {
+ CorsConfiguration corsConfig = super.getCorsConfiguration(handler, exchange);
+ if (handler instanceof HandlerMethod) {
+ HandlerMethod handlerMethod = (HandlerMethod) handler;
+ if (handlerMethod.equals(PREFLIGHT_AMBIGUOUS_MATCH)) {
+ return AbstractHandlerMethodMapping.ALLOW_CORS_CONFIG;
+ }
+ else {
+ CorsConfiguration corsConfigFromMethod = this.mappingRegistry.getCorsConfiguration(handlerMethod);
+ corsConfig = (corsConfig != null ? corsConfig.combine(corsConfigFromMethod) : corsConfigFromMethod);
+ }
+ }
+ return corsConfig;
+ }
+
// Abstract template methods
@@ -392,6 +439,9 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
private final MultiValueMap urlLookup = new LinkedMultiValueMap<>();
+ private final Map corsLookup =
+ new ConcurrentHashMap<>();
+
private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
/**
@@ -410,6 +460,14 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
return this.urlLookup.get(urlPath);
}
+ /**
+ * Return CORS configuration. Thread-safe for concurrent use.
+ */
+ public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) {
+ HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod();
+ return this.corsLookup.get(original != null ? original : handlerMethod);
+ }
+
/**
* Acquire the read lock when using getMappings and getMappingsByUrl.
*/
@@ -440,6 +498,11 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
this.urlLookup.add(url, mapping);
}
+ CorsConfiguration corsConfig = initCorsConfiguration(handler, method, mapping);
+ if (corsConfig != null) {
+ this.corsLookup.put(handlerMethod, corsConfig);
+ }
+
this.registry.put(mapping, new MappingRegistration<>(mapping, handlerMethod, directUrls));
}
finally {
@@ -486,6 +549,7 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
}
}
}
+ this.corsLookup.remove(definition.getHandlerMethod());
}
finally {
this.readWriteLock.writeLock().unlock();
@@ -561,4 +625,11 @@ public abstract class AbstractHandlerMethodMapping extends AbstractHandlerMap
}
}
+ private static class EmptyHandler {
+
+ public void handle() {
+ throw new UnsupportedOperationException("not implemented");
+ }
+ }
+
}
diff --git a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java
index b510510cba2..e87cf478553 100644
--- a/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java
+++ b/spring-web-reactive/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestMappingHandlerMapping.java
@@ -18,14 +18,20 @@ package org.springframework.web.reactive.result.method.annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method;
+import java.util.Arrays;
import java.util.Set;
import org.springframework.context.EmbeddedValueResolverAware;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
import org.springframework.util.StringValueResolver;
+import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestMethod;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.method.HandlerMethod;
import org.springframework.web.reactive.accept.RequestedContentTypeResolverBuilder;
import org.springframework.web.reactive.accept.RequestedContentTypeResolver;
import org.springframework.web.reactive.result.condition.RequestCondition;
@@ -273,4 +279,76 @@ public class RequestMappingHandlerMapping extends RequestMappingInfoHandlerMappi
}
}
+ @Override
+ protected CorsConfiguration initCorsConfiguration(Object handler, Method method, RequestMappingInfo mappingInfo) {
+ HandlerMethod handlerMethod = createHandlerMethod(handler, method);
+ CrossOrigin typeAnnotation = AnnotatedElementUtils.findMergedAnnotation(handlerMethod.getBeanType(), CrossOrigin.class);
+ CrossOrigin methodAnnotation = AnnotatedElementUtils.findMergedAnnotation(method, CrossOrigin.class);
+
+ if (typeAnnotation == null && methodAnnotation == null) {
+ return null;
+ }
+
+ CorsConfiguration config = new CorsConfiguration();
+ updateCorsConfig(config, typeAnnotation);
+ updateCorsConfig(config, methodAnnotation);
+
+ if (CollectionUtils.isEmpty(config.getAllowedOrigins())) {
+ config.setAllowedOrigins(Arrays.asList(CrossOrigin.DEFAULT_ORIGINS));
+ }
+ if (CollectionUtils.isEmpty(config.getAllowedMethods())) {
+ for (RequestMethod allowedMethod : mappingInfo.getMethodsCondition().getMethods()) {
+ config.addAllowedMethod(allowedMethod.name());
+ }
+ }
+ if (CollectionUtils.isEmpty(config.getAllowedHeaders())) {
+ config.setAllowedHeaders(Arrays.asList(CrossOrigin.DEFAULT_ALLOWED_HEADERS));
+ }
+ if (config.getAllowCredentials() == null) {
+ config.setAllowCredentials(CrossOrigin.DEFAULT_ALLOW_CREDENTIALS);
+ }
+ if (config.getMaxAge() == null) {
+ config.setMaxAge(CrossOrigin.DEFAULT_MAX_AGE);
+ }
+ return config;
+ }
+
+ private void updateCorsConfig(CorsConfiguration config, CrossOrigin annotation) {
+ if (annotation == null) {
+ return;
+ }
+ for (String origin : annotation.origins()) {
+ config.addAllowedOrigin(resolveCorsAnnotationValue(origin));
+ }
+ for (RequestMethod method : annotation.methods()) {
+ config.addAllowedMethod(method.name());
+ }
+ for (String header : annotation.allowedHeaders()) {
+ config.addAllowedHeader(resolveCorsAnnotationValue(header));
+ }
+ for (String header : annotation.exposedHeaders()) {
+ config.addExposedHeader(resolveCorsAnnotationValue(header));
+ }
+
+ String allowCredentials = resolveCorsAnnotationValue(annotation.allowCredentials());
+ if ("true".equalsIgnoreCase(allowCredentials)) {
+ config.setAllowCredentials(true);
+ }
+ else if ("false".equalsIgnoreCase(allowCredentials)) {
+ config.setAllowCredentials(false);
+ }
+ else if (!allowCredentials.isEmpty()) {
+ throw new IllegalStateException("@CrossOrigin's allowCredentials value must be \"true\", \"false\", " +
+ "or an empty string (\"\"): current value is [" + allowCredentials + "]");
+ }
+
+ if (annotation.maxAge() >= 0 && config.getMaxAge() == null) {
+ config.setMaxAge(annotation.maxAge());
+ }
+ }
+
+ private String resolveCorsAnnotationValue(String value) {
+ return (this.embeddedValueResolver != null ? this.embeddedValueResolver.resolveStringValue(value) : value);
+ }
+
}
diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java
new file mode 100644
index 00000000000..c131c28b02f
--- /dev/null
+++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/config/CorsRegistryTests.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.reactive.config;
+
+import java.util.Arrays;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.web.cors.CorsConfiguration;
+
+/**
+ * Test fixture with a {@link CorsRegistry}.
+ *
+ * @author Sebastien Deleuze
+ */
+public class CorsRegistryTests {
+
+ private CorsRegistry registry;
+
+ @Before
+ public void setUp() {
+ this.registry = new CorsRegistry();
+ }
+
+ @Test
+ public void noMapping() {
+ assertTrue(this.registry.getCorsConfigurations().isEmpty());
+ }
+
+ @Test
+ public void multipleMappings() {
+ this.registry.addMapping("/foo");
+ this.registry.addMapping("/bar");
+ assertEquals(2, this.registry.getCorsConfigurations().size());
+ }
+
+ @Test
+ public void customizedMapping() {
+ this.registry.addMapping("/foo").allowedOrigins("http://domain2.com", "http://domain2.com")
+ .allowedMethods("DELETE").allowCredentials(false).allowedHeaders("header1", "header2")
+ .exposedHeaders("header3", "header4").maxAge(3600);
+ Map configs = this.registry.getCorsConfigurations();
+ assertEquals(1, configs.size());
+ CorsConfiguration config = configs.get("/foo");
+ assertEquals(Arrays.asList("http://domain2.com", "http://domain2.com"), config.getAllowedOrigins());
+ assertEquals(Arrays.asList("DELETE"), config.getAllowedMethods());
+ assertEquals(Arrays.asList("header1", "header2"), config.getAllowedHeaders());
+ assertEquals(Arrays.asList("header3", "header4"), config.getExposedHeaders());
+ assertEquals(false, config.getAllowCredentials());
+ assertEquals(Long.valueOf(3600), config.getMaxAge());
+ }
+
+}
diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java
new file mode 100644
index 00000000000..3ed1bc6efc5
--- /dev/null
+++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/handler/CorsAbstractUrlHandlerMappingTests.java
@@ -0,0 +1,182 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.web.reactive.handler;
+
+import java.net.URISyntaxException;
+import java.util.Collections;
+
+import static org.junit.Assert.*;
+import static org.junit.Assert.assertSame;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.context.annotation.AnnotationConfigApplicationContext;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.cors.reactive.CorsConfigurationSource;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.session.MockWebSessionManager;
+import org.springframework.web.server.session.WebSessionManager;
+
+/**
+ * Unit tests for CORS support at {@link AbstractUrlHandlerMapping} level.
+ *
+ * @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
+ */
+public class CorsAbstractUrlHandlerMappingTests {
+
+ private AnnotationConfigApplicationContext wac;
+
+ private TestUrlHandlerMapping handlerMapping;
+
+ private Object mainController;
+
+ private CorsAwareHandler corsConfigurationSourceController;
+
+ @Before
+ public void setup() {
+ wac = new AnnotationConfigApplicationContext();
+ wac.register(WebConfig.class);
+ wac.refresh();
+
+ handlerMapping = (TestUrlHandlerMapping) wac.getBean("handlerMapping");
+ mainController = wac.getBean("mainController");
+ corsConfigurationSourceController = (CorsAwareHandler) wac.getBean("corsConfigurationSourceController");
+ }
+
+ @Test
+ public void actualRequestWithoutCorsConfigurationProvider() throws Exception {
+ ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertSame(mainController, actual);
+ }
+
+ @Test
+ public void preflightRequestWithoutCorsConfigurationProvider() throws Exception {
+ ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/welcome.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertEquals("NoOpHandler", actual.getClass().getSimpleName());
+ assertNull(exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ public void actualRequestWithCorsConfigurationProvider() throws Exception {
+ ServerWebExchange exchange = createExchange(HttpMethod.GET, "/cors.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertSame(corsConfigurationSourceController, actual);
+ CorsConfiguration config = ((CorsConfigurationSource)actual).getCorsConfiguration(createExchange(HttpMethod.GET, "", "",""));
+ assertNotNull(config);
+ assertArrayEquals(config.getAllowedOrigins().toArray(), new String[]{"*"});
+ assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ public void preflightRequestWithCorsConfigurationProvider() throws Exception {
+ ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/cors.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertEquals("NoOpHandler", actual.getClass().getSimpleName());
+ assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ public void actualRequestWithMappedCorsConfiguration() throws Exception {
+ CorsConfiguration mappedConfig = new CorsConfiguration();
+ mappedConfig.addAllowedOrigin("*");
+ this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig));
+
+ ServerWebExchange exchange = createExchange(HttpMethod.GET, "/welcome.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertSame(mainController, actual);
+ assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+ @Test
+ public void preflightRequestWithMappedCorsConfiguration() throws Exception {
+ CorsConfiguration mappedConfig = new CorsConfiguration();
+ mappedConfig.addAllowedOrigin("*");
+ this.handlerMapping.setCorsConfigurations(Collections.singletonMap("/welcome.html", mappedConfig));
+
+ ServerWebExchange exchange = createExchange(HttpMethod.OPTIONS, "/welcome.html", "http://domain2.com", "GET");
+ Object actual = handlerMapping.getHandler(exchange).block();
+ assertNotNull(actual);
+ assertEquals("NoOpHandler", actual.getClass().getSimpleName());
+ assertEquals("*", exchange.getResponse().getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ }
+
+
+ private ServerWebExchange createExchange(HttpMethod method, String path, String origin,
+ String accessControlRequestMethod) throws URISyntaxException {
+
+ ServerHttpRequest request = new MockServerHttpRequest(method, "http://localhost" + path);
+ request.getHeaders().add(HttpHeaders.ORIGIN, origin);
+ request.getHeaders().add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, accessControlRequestMethod);
+ WebSessionManager sessionManager = new MockWebSessionManager();
+ return new DefaultServerWebExchange(request, new MockServerHttpResponse(), sessionManager);
+ }
+
+
+ @Configuration
+ static class WebConfig {
+
+ @Bean @SuppressWarnings("unused")
+ public TestUrlHandlerMapping handlerMapping() {
+ TestUrlHandlerMapping hm = new TestUrlHandlerMapping();
+ hm.setUseTrailingSlashMatch(true);
+ hm.registerHandler("/welcome.html", mainController());
+ hm.registerHandler("/cors.html", corsConfigurationSourceController());
+ return hm;
+ }
+
+ @Bean
+ public Object mainController() {
+ return new Object();
+ }
+
+ @Bean
+ public CorsAwareHandler corsConfigurationSourceController() {
+ return new CorsAwareHandler();
+ }
+
+ }
+
+ static class TestUrlHandlerMapping extends AbstractUrlHandlerMapping {
+
+ }
+
+ static class CorsAwareHandler implements CorsConfigurationSource {
+
+ @Override
+ public CorsConfiguration getCorsConfiguration(ServerWebExchange exchange) {
+ CorsConfiguration config = new CorsConfiguration();
+ config.addAllowedOrigin("*");
+ return config;
+ }
+ }
+
+}
diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java
index 1d79704359c..b2aa35a3bfd 100644
--- a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java
+++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/AbstractRequestMappingIntegrationTests.java
@@ -27,6 +27,7 @@ import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
+import org.springframework.web.server.handler.ResponseStatusExceptionHandler;
import static org.springframework.http.RequestEntity.get;
@@ -46,6 +47,7 @@ public abstract class AbstractRequestMappingIntegrationTests extends AbstractHtt
this.applicationContext = initApplicationContext();
return WebHttpHandlerBuilder
.webHandler(new DispatcherHandler(this.applicationContext))
+ .exceptionHandlers(new ResponseStatusExceptionHandler())
.build();
}
diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java
new file mode 100644
index 00000000000..b461559e2a2
--- /dev/null
+++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CorsConfigurationIntegrationTests.java
@@ -0,0 +1,183 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.reactive.result.method.annotation;
+
+import static org.junit.Assert.*;
+import org.junit.Test;
+
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.annotation.AnnotationConfigApplicationContext;
+import org.springframework.context.annotation.ComponentScan;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.http.HttpEntity;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.ResponseEntity;
+import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.client.HttpClientErrorException;
+import org.springframework.web.client.RestTemplate;
+import org.springframework.web.reactive.config.CorsRegistry;
+import org.springframework.web.reactive.config.WebReactiveConfiguration;
+
+/**
+ * @author Sebastien Deleuze
+ */
+public class CorsConfigurationIntegrationTests extends AbstractRequestMappingIntegrationTests {
+
+ // JDK default HTTP client blacklist headers like Origin
+ private RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
+
+ @Override
+ protected ApplicationContext initApplicationContext() {
+ AnnotationConfigApplicationContext wac = new AnnotationConfigApplicationContext();
+ wac.register(WebConfig.class);
+ wac.refresh();
+ return wac;
+ }
+
+ @Override
+ RestTemplate getRestTemplate() {
+ return this.restTemplate;
+ }
+
+ @Test
+ public void actualRequestWithCorsEnabled() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/cors"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("cors", entity.getBody());
+ }
+
+ @Test
+ public void actualRequestWithCorsRejected() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ try {
+ this.restTemplate.exchange(getUrl("/cors-restricted"), HttpMethod.GET,
+ requestEntity, String.class);
+ }
+ catch (HttpClientErrorException e) {
+ assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode());
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void actualRequestWithoutCorsEnabled() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/welcome"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertNull(entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("welcome", entity.getBody());
+ }
+
+ @Test
+ public void preflightRequestWithCorsEnabled() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/cors"),
+ HttpMethod.OPTIONS, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://localhost:9000", entity.getHeaders().getAccessControlAllowOrigin());
+ }
+
+ @Test
+ public void preflightRequestWithCorsRejected() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ try {
+ this.restTemplate.exchange(getUrl("/cors-restricted"), HttpMethod.OPTIONS,
+ requestEntity, String.class);
+ }
+ catch (HttpClientErrorException e) {
+ assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode());
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void preflightRequestWithoutCorsEnabled() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://localhost:9000");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ try {
+ this.restTemplate.exchange(getUrl("/welcome"), HttpMethod.OPTIONS,
+ requestEntity, String.class);
+ }
+ catch (HttpClientErrorException e) {
+ assertEquals(HttpStatus.FORBIDDEN, e.getStatusCode());
+ return;
+ }
+ fail();
+ }
+
+ private String getUrl(String path) {
+ return "http://localhost:" + this.port + path;
+ }
+
+
+ @Configuration
+ @ComponentScan(resourcePattern = "**/CorsConfigurationIntegrationTests*.class")
+ @SuppressWarnings({"unused", "WeakerAccess"})
+ static class WebConfig extends WebReactiveConfiguration {
+
+ @Override
+ protected void addCorsMappings(CorsRegistry registry) {
+ registry.addMapping("/cors-restricted").allowedOrigins("http://foo");
+ registry.addMapping("/cors");
+ }
+ }
+
+ @RestController
+ static class TestController {
+
+ @GetMapping("/welcome")
+ public String welcome() {
+ return "welcome";
+ }
+
+ @GetMapping("/cors")
+ public String cors() {
+ return "cors";
+ }
+
+ @GetMapping("/cors-restricted")
+ public String corsRestricted() {
+ return "corsRestricted";
+ }
+
+ }
+
+}
diff --git a/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java
new file mode 100644
index 00000000000..1e5a9fdc0f1
--- /dev/null
+++ b/spring-web-reactive/src/test/java/org/springframework/web/reactive/result/method/annotation/CrossOriginAnnotationIntegrationTests.java
@@ -0,0 +1,345 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.reactive.result.method.annotation;
+
+import java.util.Properties;
+
+import static org.junit.Assert.*;
+import static org.junit.Assert.assertArrayEquals;
+import org.junit.Test;
+
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.annotation.AnnotationConfigApplicationContext;
+import org.springframework.context.annotation.ComponentScan;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
+import org.springframework.core.env.PropertiesPropertySource;
+import org.springframework.http.HttpEntity;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.ResponseEntity;
+import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
+import org.springframework.web.bind.annotation.CrossOrigin;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestMethod;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.client.RestTemplate;
+import org.springframework.web.reactive.config.WebReactiveConfiguration;
+
+/**
+ * @author Sebastien Deleuze
+ */
+public class CrossOriginAnnotationIntegrationTests extends AbstractRequestMappingIntegrationTests {
+
+ // JDK default HTTP client blacklist headers like Origin
+ private RestTemplate restTemplate = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
+
+
+ @Override
+ protected ApplicationContext initApplicationContext() {
+ AnnotationConfigApplicationContext wac = new AnnotationConfigApplicationContext();
+ wac.register(WebConfig.class);
+ Properties props = new Properties();
+ props.setProperty("myOrigin", "http://site1.com");
+ wac.getEnvironment().getPropertySources().addFirst(new PropertiesPropertySource("ps", props));
+ wac.register(PropertySourcesPlaceholderConfigurer.class);
+ wac.refresh();
+ return wac;
+ }
+
+ @Override
+ RestTemplate getRestTemplate() {
+ return this.restTemplate;
+ }
+
+ @Test
+ public void actualGetRequestWithoutAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/no"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertNull(entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("no", entity.getBody());
+ }
+
+ @Test
+ public void actualPostRequestWithoutAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/no"),
+ HttpMethod.POST, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertNull(entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("no-post", entity.getBody());
+ }
+
+ @Test
+ public void actualRequestWithDefaultAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
+ assertEquals("default", entity.getBody());
+ }
+
+ @Test
+ public void preflightRequestWithDefaultAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"),
+ HttpMethod.OPTIONS, requestEntity, Void.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(1800, entity.getHeaders().getAccessControlMaxAge());
+ assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
+ }
+
+ @Test
+ public void actualRequestWithDefaultAnnotationAndNoOrigin() {
+ HttpHeaders headers = new HttpHeaders();
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/default"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertNull(entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("default", entity.getBody());
+ }
+
+ @Test
+ public void actualRequestWithCustomizedAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/customized"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials());
+ assertEquals(-1, entity.getHeaders().getAccessControlMaxAge());
+ assertEquals("customized", entity.getBody());
+ }
+
+ @Test
+ public void preflightRequestWithCustomizedAnnotation() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1, header2");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/customized"),
+ HttpMethod.OPTIONS, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray());
+ assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials());
+ assertArrayEquals(new String[] {"header1", "header2"}, entity.getHeaders().getAccessControlAllowHeaders().toArray());
+ assertArrayEquals(new String[] {"header3", "header4"}, entity.getHeaders().getAccessControlExposeHeaders().toArray());
+ assertEquals(123, entity.getHeaders().getAccessControlMaxAge());
+ }
+
+ @Test
+ public void customOriginDefinedViaValueAttribute() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/origin-value-attribute"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("value-attribute", entity.getBody());
+ }
+
+ @Test
+ public void customOriginDefinedViaPlaceholder() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/origin-placeholder"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals("placeholder", entity.getBody());
+ }
+
+ @Test
+ public void classLevel() {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/foo"),
+ HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials());
+ assertEquals("foo", entity.getBody());
+
+ entity = this.restTemplate.exchange(getUrl("/bar"), HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("*", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(false, entity.getHeaders().getAccessControlAllowCredentials());
+ assertEquals("bar", entity.getBody());
+
+ entity = this.restTemplate.exchange(getUrl("/baz"), HttpMethod.GET, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
+ assertEquals("baz", entity.getBody());
+ }
+
+ @Test
+ public void ambiguousHeaderPreflightRequest() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "header1");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/ambiguous-header"),
+ HttpMethod.OPTIONS, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray());
+ assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
+ assertArrayEquals(new String[] {"header1"}, entity.getHeaders().getAccessControlAllowHeaders().toArray());
+ }
+
+ @Test
+ public void ambiguousProducesPreflightRequest() throws Exception {
+ HttpHeaders headers = new HttpHeaders();
+ headers.add(HttpHeaders.ORIGIN, "http://site1.com");
+ headers.add(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ HttpEntity> requestEntity = new HttpEntity(headers);
+ ResponseEntity entity = this.restTemplate.exchange(getUrl("/ambiguous-produces"),
+ HttpMethod.OPTIONS, requestEntity, String.class);
+ assertEquals(HttpStatus.OK, entity.getStatusCode());
+ assertEquals("http://site1.com", entity.getHeaders().getAccessControlAllowOrigin());
+ assertArrayEquals(new HttpMethod[] {HttpMethod.GET}, entity.getHeaders().getAccessControlAllowMethods().toArray());
+ assertEquals(true, entity.getHeaders().getAccessControlAllowCredentials());
+ }
+
+ private String getUrl(String path) {
+ return "http://localhost:" + this.port + path;
+ }
+
+
+ @Configuration
+ @ComponentScan(resourcePattern = "**/CrossOriginAnnotationIntegrationTests*")
+ @SuppressWarnings({"unused", "WeakerAccess"})
+ static class WebConfig extends WebReactiveConfiguration {
+
+ }
+
+ @RestController
+ private static class MethodLevelController {
+
+ @RequestMapping(path = "/no", method = RequestMethod.GET)
+ public String noAnnotation() {
+ return "no";
+ }
+
+ @RequestMapping(path = "/no", method = RequestMethod.POST)
+ public String noAnnotationPost() {
+ return "no-post";
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/default", method = RequestMethod.GET)
+ public String defaultAnnotation() {
+ return "default";
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/default", method = RequestMethod.GET, params = "q")
+ public void defaultAnnotationWithParams() {
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=a")
+ public void ambigousHeader1a() {
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/ambiguous-header", method = RequestMethod.GET, headers = "header1=b")
+ public void ambigousHeader1b() {
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/xml")
+ public String ambigousProducesXml() {
+ return "";
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/ambiguous-produces", method = RequestMethod.GET, produces = "application/json")
+ public String ambigousProducesJson() {
+ return "{}";
+ }
+
+ @CrossOrigin(origins = { "http://site1.com", "http://site2.com" }, allowedHeaders = { "header1", "header2" },
+ exposedHeaders = { "header3", "header4" }, methods = RequestMethod.GET, maxAge = 123, allowCredentials = "false")
+ @RequestMapping(path = "/customized", method = { RequestMethod.GET, RequestMethod.POST })
+ public String customized() {
+ return "customized";
+ }
+
+ @CrossOrigin("http://site1.com")
+ @RequestMapping("/origin-value-attribute")
+ public String customOriginDefinedViaValueAttribute() {
+ return "value-attribute";
+ }
+
+ @CrossOrigin("${myOrigin}")
+ @RequestMapping("/origin-placeholder")
+ public String customOriginDefinedViaPlaceholder() {
+ return "placeholder";
+ }
+ }
+
+ @RestController
+ @CrossOrigin(allowCredentials = "false")
+ private static class ClassLevelController {
+
+ @RequestMapping(path = "/foo", method = RequestMethod.GET)
+ public String foo() {
+ return "foo";
+ }
+
+ @CrossOrigin
+ @RequestMapping(path = "/bar", method = RequestMethod.GET)
+ public String bar() {
+ return "bar";
+ }
+
+ @CrossOrigin(allowCredentials = "true")
+ @RequestMapping(path = "/baz", method = RequestMethod.GET)
+ public String baz() {
+ return "baz";
+ }
+
+ }
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java
new file mode 100644
index 00000000000..c0fc6e8f92f
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsConfigurationSource.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.server.ServerWebExchange;
+
+/**
+ * Interface to be implemented by classes (usually HTTP request handlers) that
+ * provides a {@link CorsConfiguration} instance based on the provided reactive request.
+ *
+ * @author Sebastien Deleuze
+ * @since 5.0
+ */
+public interface CorsConfigurationSource {
+
+ /**
+ * Return a {@link CorsConfiguration} based on the incoming request.
+ * @return the associated {@link CorsConfiguration}, or {@code null} if none
+ */
+ CorsConfiguration getCorsConfiguration(ServerWebExchange exchange);
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java
new file mode 100644
index 00000000000..b516b77893e
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsProcessor.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import reactor.core.publisher.Mono;
+
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.server.ServerWebExchange;
+
+/**
+ * A strategy that takes a reactive request and a {@link CorsConfiguration} and updates
+ * the response.
+ *
+ * This component is not concerned with how a {@code CorsConfiguration} is
+ * selected but rather takes follow-up actions such as applying CORS validation
+ * checks and either rejecting the response or adding CORS headers to the
+ * response.
+ *
+ * @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
+ * @since 5.0
+ * @see CORS W3C recommandation
+ */
+public interface CorsProcessor {
+
+ /**
+ * Process a request given a {@code CorsConfiguration}.
+ * @param configuration the applicable CORS configuration (possibly {@code null})
+ * @param exchange the current HTTP request / response
+ * @return a {@link Mono} emitting {@code false} if the request is rejected, {@code true} otherwise
+ */
+ boolean processRequest(CorsConfiguration configuration, ServerWebExchange exchange);
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java
new file mode 100644
index 00000000000..4431c40e0c6
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/CorsUtils.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.util.Assert;
+import org.springframework.web.util.UriComponents;
+import org.springframework.web.util.UriComponentsBuilder;
+
+;
+
+/**
+ * Utility class for CORS reactive request handling based on the
+ * CORS W3C recommendation.
+ *
+ * @author Sebastien Deleuze
+ * @since 5.0
+ */
+public abstract class CorsUtils {
+
+ /**
+ * Returns {@code true} if the request is a valid CORS one.
+ */
+ public static boolean isCorsRequest(ServerHttpRequest request) {
+ return (request.getHeaders().get(HttpHeaders.ORIGIN) != null);
+ }
+
+ /**
+ * Returns {@code true} if the request is a valid CORS pre-flight one.
+ */
+ public static boolean isPreFlightRequest(ServerHttpRequest request) {
+ return (isCorsRequest(request) && HttpMethod.OPTIONS == request.getMethod() &&
+ request.getHeaders().get(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD) != null);
+ }
+
+ /**
+ * Check if the request is a same-origin one, based on {@code Origin}, {@code Host},
+ * {@code Forwarded} and {@code X-Forwarded-Host} headers.
+ * @return {@code true} if the request is a same-origin one, {@code false} in case
+ * of cross-origin request.
+ */
+ public static boolean isSameOrigin(ServerHttpRequest request) {
+ String origin = request.getHeaders().getOrigin();
+ if (origin == null) {
+ return true;
+ }
+ UriComponentsBuilder urlBuilder = UriComponentsBuilder.fromHttpRequest(request);
+ UriComponents actualUrl = urlBuilder.build();
+ String actualHost = actualUrl.getHost();
+ int actualPort = getPort(actualUrl);
+ Assert.notNull(actualHost, "Actual request host must not be null");
+ Assert.isTrue(actualPort != -1, "Actual request port must not be undefined");
+ UriComponents originUrl = UriComponentsBuilder.fromOriginHeader(origin).build();
+ return (actualHost.equals(originUrl.getHost()) && actualPort == getPort(originUrl));
+ }
+
+ private static int getPort(UriComponents uri) {
+ int port = uri.getPort();
+ if (port == -1) {
+ if ("http".equals(uri.getScheme()) || "ws".equals(uri.getScheme())) {
+ port = 80;
+ }
+ else if ("https".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) {
+ port = 443;
+ }
+ }
+ return port;
+ }
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java
new file mode 100644
index 00000000000..f2b7c57858f
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/DefaultCorsProcessor.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.util.CollectionUtils;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.util.WebUtils;
+
+/**
+ * The default implementation of {@link CorsProcessor},
+ * as defined by the CORS W3C recommendation.
+ *
+ *
Note that when input {@link CorsConfiguration} is {@code null}, this
+ * implementation does not reject simple or actual requests outright but simply
+ * avoid adding CORS headers to the response. CORS processing is also skipped
+ * if the response already contains CORS headers, or if the request is detected
+ * as a same-origin one.
+ *
+ * @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
+ * @since 5.0
+ */
+public class DefaultCorsProcessor implements CorsProcessor {
+
+ private static final Log logger = LogFactory.getLog(DefaultCorsProcessor.class);
+
+
+ @Override
+ @SuppressWarnings("resource")
+ public boolean processRequest(CorsConfiguration config, ServerWebExchange exchange) {
+
+ ServerHttpRequest request = exchange.getRequest();
+ ServerHttpResponse response = exchange.getResponse();
+
+ if (!CorsUtils.isCorsRequest(request)) {
+ return true;
+ }
+
+ if (responseHasCors(response)) {
+ logger.debug("Skip CORS processing: response already contains \"Access-Control-Allow-Origin\" header");
+ return true;
+ }
+
+ if (CorsUtils.isSameOrigin(request)) {
+ logger.debug("Skip CORS processing: request is from same origin");
+ return true;
+ }
+
+ boolean preFlightRequest = CorsUtils.isPreFlightRequest(request);
+ if (config == null) {
+ if (preFlightRequest) {
+ rejectRequest(response);
+ return false;
+ }
+ else {
+ return true;
+ }
+ }
+
+ return handleInternal(exchange, config, preFlightRequest);
+ }
+
+ private boolean responseHasCors(ServerHttpResponse response) {
+ return (response.getHeaders().getAccessControlAllowOrigin() != null);
+ }
+
+ /**
+ * Invoked when one of the CORS checks failed.
+ */
+ protected void rejectRequest(ServerHttpResponse response) {
+ response.setStatusCode(HttpStatus.FORBIDDEN);
+ logger.debug("Invalid CORS request");
+ }
+
+ /**
+ * Handle the given request.
+ */
+ protected boolean handleInternal(ServerWebExchange exchange,
+ CorsConfiguration config, boolean preFlightRequest) {
+
+ ServerHttpRequest request = exchange.getRequest();
+ ServerHttpResponse response = exchange.getResponse();
+
+ String requestOrigin = request.getHeaders().getOrigin();
+ String allowOrigin = checkOrigin(config, requestOrigin);
+
+ HttpMethod requestMethod = getMethodToUse(request, preFlightRequest);
+ List allowMethods = checkMethods(config, requestMethod);
+
+ List requestHeaders = getHeadersToUse(request, preFlightRequest);
+ List allowHeaders = checkHeaders(config, requestHeaders);
+
+ if (allowOrigin == null || allowMethods == null || (preFlightRequest && allowHeaders == null)) {
+ rejectRequest(response);
+ return false;
+ }
+
+ HttpHeaders responseHeaders = response.getHeaders();
+ responseHeaders.setAccessControlAllowOrigin(allowOrigin);
+ responseHeaders.add(HttpHeaders.VARY, HttpHeaders.ORIGIN);
+
+ if (preFlightRequest) {
+ responseHeaders.setAccessControlAllowMethods(allowMethods);
+ }
+
+ if (preFlightRequest && !allowHeaders.isEmpty()) {
+ responseHeaders.setAccessControlAllowHeaders(allowHeaders);
+ }
+
+ if (!CollectionUtils.isEmpty(config.getExposedHeaders())) {
+ responseHeaders.setAccessControlExposeHeaders(config.getExposedHeaders());
+ }
+
+ if (Boolean.TRUE.equals(config.getAllowCredentials())) {
+ responseHeaders.setAccessControlAllowCredentials(true);
+ }
+
+ if (preFlightRequest && config.getMaxAge() != null) {
+ responseHeaders.setAccessControlMaxAge(config.getMaxAge());
+ }
+
+ return true;
+ }
+
+ /**
+ * Check the origin and determine the origin for the response. The default
+ * implementation simply delegates to
+ * {@link CorsConfiguration#checkOrigin(String)}.
+ */
+ protected String checkOrigin(CorsConfiguration config, String requestOrigin) {
+ return config.checkOrigin(requestOrigin);
+ }
+
+ /**
+ * Check the HTTP method and determine the methods for the response of a
+ * pre-flight request. The default implementation simply delegates to
+ * {@link CorsConfiguration#checkOrigin(String)}.
+ */
+ protected List checkMethods(CorsConfiguration config, HttpMethod requestMethod) {
+ return config.checkHttpMethod(requestMethod);
+ }
+
+ private HttpMethod getMethodToUse(ServerHttpRequest request, boolean isPreFlight) {
+ return (isPreFlight ? request.getHeaders().getAccessControlRequestMethod() : request.getMethod());
+ }
+
+ /**
+ * Check the headers and determine the headers for the response of a
+ * pre-flight request. The default implementation simply delegates to
+ * {@link CorsConfiguration#checkOrigin(String)}.
+ */
+ protected List checkHeaders(CorsConfiguration config, List requestHeaders) {
+ return config.checkHeaders(requestHeaders);
+ }
+
+ private List getHeadersToUse(ServerHttpRequest request, boolean isPreFlight) {
+ HttpHeaders headers = request.getHeaders();
+ return (isPreFlight ? headers.getAccessControlRequestHeaders() : new ArrayList<>(headers.keySet()));
+ }
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java
new file mode 100644
index 00000000000..c6b78d77e8e
--- /dev/null
+++ b/spring-web/src/main/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSource.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.springframework.util.AntPathMatcher;
+import org.springframework.util.Assert;
+import org.springframework.util.PathMatcher;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.util.HttpRequestPathHelper;
+
+/**
+ * Provide a per reactive request {@link CorsConfiguration} instance based on a
+ * collection of {@link CorsConfiguration} mapped on path patterns.
+ *
+ * Exact path mapping URIs (such as {@code "/admin"}) are supported
+ * as well as Ant-style path patterns (such as {@code "/admin/**"}).
+ *
+ * @author Sebastien Deleuze
+ * @since 5.0
+ */
+public class UrlBasedCorsConfigurationSource implements CorsConfigurationSource {
+
+ private final Map corsConfigurations = new LinkedHashMap<>();
+
+ private PathMatcher pathMatcher = new AntPathMatcher();
+
+ private HttpRequestPathHelper pathHelper = new HttpRequestPathHelper();
+
+
+ /**
+ * Set the PathMatcher implementation to use for matching URL paths
+ * against registered URL patterns. Default is AntPathMatcher.
+ * @see AntPathMatcher
+ */
+ public void setPathMatcher(PathMatcher pathMatcher) {
+ Assert.notNull(pathMatcher, "PathMatcher must not be null");
+ this.pathMatcher = pathMatcher;
+ }
+
+ /**
+ * Set if context path and request URI should be URL-decoded. Both are returned
+ * undecoded by the Servlet API, in contrast to the servlet path.
+ * Uses either the request encoding or the default encoding according
+ * to the Servlet spec (ISO-8859-1).
+ * @see HttpRequestPathHelper#setUrlDecode
+ */
+ public void setUrlDecode(boolean urlDecode) {
+ this.pathHelper.setUrlDecode(urlDecode);
+ }
+
+ /**
+ * Set the UrlPathHelper to use for resolution of lookup paths.
+ *
Use this to override the default UrlPathHelper with a custom subclass.
+ */
+ public void setHttpRequestPathHelper(HttpRequestPathHelper pathHelper) {
+ Assert.notNull(pathHelper, "HttpRequestPathHelper must not be null");
+ this.pathHelper = pathHelper;
+ }
+
+ /**
+ * Set CORS configuration based on URL patterns.
+ */
+ public void setCorsConfigurations(Map corsConfigurations) {
+ this.corsConfigurations.clear();
+ if (corsConfigurations != null) {
+ this.corsConfigurations.putAll(corsConfigurations);
+ }
+ }
+
+ /**
+ * Get the CORS configuration.
+ */
+ public Map getCorsConfigurations() {
+ return Collections.unmodifiableMap(this.corsConfigurations);
+ }
+
+ /**
+ * Register a {@link CorsConfiguration} for the specified path pattern.
+ */
+ public void registerCorsConfiguration(String path, CorsConfiguration config) {
+ this.corsConfigurations.put(path, config);
+ }
+
+
+ @Override
+ public CorsConfiguration getCorsConfiguration(ServerWebExchange exchange) {
+ String lookupPath = this.pathHelper.getLookupPathForRequest(exchange);
+ for (Map.Entry entry : this.corsConfigurations.entrySet()) {
+ if (this.pathMatcher.match(entry.getKey(), lookupPath)) {
+ return entry.getValue();
+ }
+ }
+ return null;
+ }
+
+}
diff --git a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java
index 98596580adf..09495d39777 100644
--- a/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java
+++ b/spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java
@@ -710,8 +710,8 @@ public class UriComponentsBuilder implements Cloneable {
}
}
- if ((this.scheme.equals("http") && "80".equals(this.port)) ||
- (this.scheme.equals("https") && "443".equals(this.port))) {
+ if ((this.scheme != null) && ((this.scheme.equals("http") && "80".equals(this.port)) ||
+ (this.scheme.equals("https") && "443".equals(this.port)))) {
this.port = null;
}
diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java
new file mode 100644
index 00000000000..a983350ed06
--- /dev/null
+++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/CorsUtilsTests.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import org.junit.Test;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
+import org.springframework.web.cors.reactive.CorsUtils;
+
+/**
+ * Test case for reactive {@link CorsUtils}.
+ *
+ * @author Sebastien Deleuze
+ */
+public class CorsUtilsTests {
+
+ @Test
+ public void isCorsRequest() {
+ MockServerHttpRequest request = new MockServerHttpRequest();
+ request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
+ assertTrue(CorsUtils.isCorsRequest(request));
+ }
+
+ @Test
+ public void isNotCorsRequest() {
+ MockServerHttpRequest request = new MockServerHttpRequest();
+ assertFalse(CorsUtils.isCorsRequest(request));
+ }
+
+ @Test
+ public void isPreFlightRequest() {
+ MockServerHttpRequest request = new MockServerHttpRequest();
+ request.setHttpMethod(HttpMethod.OPTIONS);
+ request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
+ request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ assertTrue(CorsUtils.isPreFlightRequest(request));
+ }
+
+ @Test
+ public void isNotPreFlightRequest() {
+ MockServerHttpRequest request = new MockServerHttpRequest();
+ assertFalse(CorsUtils.isPreFlightRequest(request));
+
+ request = new MockServerHttpRequest();
+ request.setHttpMethod(HttpMethod.OPTIONS);
+ request.addHeader(HttpHeaders.ORIGIN, "http://domain.com");
+ assertFalse(CorsUtils.isPreFlightRequest(request));
+
+ request = new MockServerHttpRequest();
+ request.setHttpMethod(HttpMethod.OPTIONS);
+ request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ assertFalse(CorsUtils.isPreFlightRequest(request));
+ }
+
+}
diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java
new file mode 100644
index 00000000000..a2cb70958d4
--- /dev/null
+++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/DefaultCorsProcessorTests.java
@@ -0,0 +1,351 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import static org.junit.Assert.*;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.cors.reactive.DefaultCorsProcessor;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.session.MockWebSessionManager;
+
+/**
+ * Test reactive {@link DefaultCorsProcessor} with simple or preflight CORS request.
+ *
+ * @author Sebastien Deleuze
+ * @author Rossen Stoyanchev
+ * @author Juergen Hoeller
+ */
+public class DefaultCorsProcessorTests {
+
+ private MockServerHttpRequest request;
+
+ private MockServerHttpResponse response;
+
+ private ServerWebExchange exchange;
+
+ private DefaultCorsProcessor processor;
+
+ private CorsConfiguration conf;
+
+
+ @Before
+ public void setup() {
+ this.request = new MockServerHttpRequest();
+ this.request.setUri("http://localhost/test.html");
+ this.conf = new CorsConfiguration();
+ this.response = new MockServerHttpResponse();
+ this.response.setStatusCode(HttpStatus.OK);
+ this.processor = new DefaultCorsProcessor();
+ this.exchange = new DefaultServerWebExchange(this.request, this.response, new MockWebSessionManager());
+ }
+
+
+ @Test
+ public void actualRequestWithOriginHeader() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestWithOriginHeaderAndNullConfig() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+
+ this.processor.processRequest(null, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestWithOriginHeaderAndAllowedOrigin() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.conf.addAllowedOrigin("*");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("*", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestCredentials() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.conf.addAllowedOrigin("http://domain1.com");
+ this.conf.addAllowedOrigin("http://domain2.com");
+ this.conf.addAllowedOrigin("http://domain3.com");
+ this.conf.setAllowCredentials(true);
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestCredentialsWithOriginWildcard() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.conf.addAllowedOrigin("*");
+ this.conf.setAllowCredentials(true);
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestCaseInsensitiveOriginMatch() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.conf.addAllowedOrigin("http://DOMAIN2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void actualRequestExposedHeaders() throws Exception {
+ this.request.setHttpMethod(HttpMethod.GET);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.conf.addExposedHeader("header1");
+ this.conf.addExposedHeader("header2");
+ this.conf.addAllowedOrigin("http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header1"));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_EXPOSE_HEADERS).contains("header2"));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestAllOriginsAllowed() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.conf.addAllowedOrigin("*");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestWrongAllowedMethod() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "DELETE");
+ this.conf.addAllowedOrigin("*");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestMatchedAllowedMethod() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.conf.addAllowedOrigin("*");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ assertEquals("GET,HEAD", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
+ }
+
+ @Test
+ public void preflightRequestTestWithOriginButWithoutOtherHeaders() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestWithoutRequestMethod() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestWithRequestAndMethodHeaderButNoConfig() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestValidRequestAndConfig() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
+ this.conf.addAllowedOrigin("*");
+ this.conf.addAllowedMethod("GET");
+ this.conf.addAllowedMethod("PUT");
+ this.conf.addAllowedHeader("header1");
+ this.conf.addAllowedHeader("header2");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("*", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
+ assertEquals("GET,PUT", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS));
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_MAX_AGE));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestCredentials() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
+ this.conf.addAllowedOrigin("http://domain1.com");
+ this.conf.addAllowedOrigin("http://domain2.com");
+ this.conf.addAllowedOrigin("http://domain3.com");
+ this.conf.addAllowedHeader("Header1");
+ this.conf.setAllowCredentials(true);
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals("true", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestCredentialsWithOriginWildcard() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1");
+ this.conf.addAllowedOrigin("http://domain1.com");
+ this.conf.addAllowedOrigin("*");
+ this.conf.addAllowedOrigin("http://domain3.com");
+ this.conf.addAllowedHeader("Header1");
+ this.conf.setAllowCredentials(true);
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals("http://domain2.com", this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestAllowedHeaders() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2");
+ this.conf.addAllowedHeader("Header1");
+ this.conf.addAllowedHeader("Header2");
+ this.conf.addAllowedHeader("Header3");
+ this.conf.addAllowedOrigin("http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
+ assertFalse(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header3"));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestAllowsAllHeaders() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "Header1, Header2");
+ this.conf.addAllowedHeader("*");
+ this.conf.addAllowedOrigin("http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header1"));
+ assertTrue(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("Header2"));
+ assertFalse(this.response.getHeaders().getFirst(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS).contains("*"));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestWithEmptyHeaders() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS, "");
+ this.conf.addAllowedHeader("*");
+ this.conf.addAllowedOrigin("http://domain2.com");
+
+ this.processor.processRequest(this.conf, this.exchange);
+ assertTrue(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS));
+ assertEquals(HttpStatus.OK, this.response.getStatusCode());
+ }
+
+ @Test
+ public void preflightRequestWithNullConfig() throws Exception {
+ this.request.setHttpMethod(HttpMethod.OPTIONS);
+ this.request.addHeader(HttpHeaders.ORIGIN, "http://domain2.com");
+ this.request.addHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD, "GET");
+ this.conf.addAllowedOrigin("*");
+
+ this.processor.processRequest(null, this.exchange);
+ assertFalse(this.response.getHeaders().containsKey(HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN));
+ assertEquals(HttpStatus.FORBIDDEN, this.response.getStatusCode());
+ }
+
+}
diff --git a/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java
new file mode 100644
index 00000000000..f0be0bfaded
--- /dev/null
+++ b/spring-web/src/test/java/org/springframework/web/cors/reactive/UrlBasedCorsConfigurationSourceTests.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.cors.reactive;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import org.junit.Test;
+
+import org.springframework.http.HttpMethod;
+import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
+import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse;
+import org.springframework.web.cors.CorsConfiguration;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.adapter.DefaultServerWebExchange;
+import org.springframework.web.server.session.MockWebSessionManager;
+
+/**
+ * Unit tests for reactive {@link UrlBasedCorsConfigurationSource}.
+ * @author Sebastien Deleuze
+ */
+public class UrlBasedCorsConfigurationSourceTests {
+
+ private final UrlBasedCorsConfigurationSource configSource = new UrlBasedCorsConfigurationSource();
+
+ @Test
+ public void empty() {
+ ServerHttpRequest request = new MockServerHttpRequest(HttpMethod.GET, "/bar/test.html");
+ ServerWebExchange exchange = new DefaultServerWebExchange(request,
+ new MockServerHttpResponse(), new MockWebSessionManager());
+ assertNull(this.configSource.getCorsConfiguration(exchange));
+ }
+
+ @Test
+ public void registerAndMatch() {
+ CorsConfiguration config = new CorsConfiguration();
+ this.configSource.registerCorsConfiguration("/bar/**", config);
+ assertNull(this.configSource.getCorsConfiguration(
+ new DefaultServerWebExchange(
+ new MockServerHttpRequest(HttpMethod.GET, "/foo/test.html"),
+ new MockServerHttpResponse(),
+ new MockWebSessionManager())));
+ assertEquals(config, this.configSource.getCorsConfiguration(new DefaultServerWebExchange(
+ new MockServerHttpRequest(HttpMethod.GET, "/bar/test.html"),
+ new MockServerHttpResponse(),
+ new MockWebSessionManager())));
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void unmodifiableConfigurationsMap() {
+ this.configSource.getCorsConfigurations().put("/**", new CorsConfiguration());
+ }
+
+}