diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java index e89b7da9256..5240d033e79 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ClientRequest.java @@ -18,6 +18,7 @@ package org.springframework.web.reactive.function.client; import java.net.URI; import java.util.Map; +import java.util.Optional; import java.util.function.Consumer; import org.reactivestreams.Publisher; @@ -69,6 +70,22 @@ public interface ClientRequest { */ BodyInserter body(); + /** + * Return the request attribute value if present. + * @param name the attribute name + * @return the attribute value + */ + default Optional attribute(String name) { + Map attributes = attributes(); + if (attributes.containsKey(name)) { + return Optional.of(attributes.get(name)); + } + else { + return Optional.empty(); + } + } + + /** * Return the attributes of this request. */ diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java index 315e626fa2d..0f8481f4c7c 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java @@ -16,9 +16,10 @@ package org.springframework.web.reactive.function.client; +import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.Base64; -import java.util.Map; +import java.util.Optional; import java.util.function.Function; import reactor.core.publisher.Mono; @@ -60,57 +61,72 @@ public abstract class ExchangeFilterFunctions { Assert.notNull(username, "'username' must not be null"); Assert.notNull(password, "'password' must not be null"); - return basicAuthentication(r -> username, r -> password); + return basicAuthenticationInternal(r -> Optional.of(new Credentials(username, password))); } /** * Return a filter that adds an Authorization header for HTTP Basic Authentication, based on * the username and password provided in the - * {@linkplain ClientRequest#attributes() request attributes}. + * {@linkplain ClientRequest#attributes() request attributes}. If the attributes are not found, + * no authorization header * @return the {@link ExchangeFilterFunction} that adds the Authorization header * @see #USERNAME_ATTRIBUTE * @see #PASSWORD_ATTRIBUTE */ public static ExchangeFilterFunction basicAuthentication() { - return basicAuthentication( - request -> getRequiredAttribute(request, USERNAME_ATTRIBUTE), - request -> getRequiredAttribute(request, PASSWORD_ATTRIBUTE) - ); - } - - private static String getRequiredAttribute(ClientRequest request, String key) { - Map attributes = request.attributes(); - if (attributes.containsKey(key)) { - return (String) attributes.get(key); - } else { - throw new IllegalStateException( - "Could not find request attribute with key \"" + key + "\""); - } + return basicAuthenticationInternal( + request -> { + Optional username = request.attribute(USERNAME_ATTRIBUTE).map(o -> (String)o); + Optional password = request.attribute(PASSWORD_ATTRIBUTE).map(o -> (String)o); + if (username.isPresent() && password.isPresent()) { + return Optional.of(new Credentials(username.get(), password.get())); + } else { + return Optional.empty(); + } + }); } - private static ExchangeFilterFunction basicAuthentication(Function usernameFunction, - Function passwordFunction) { + private static ExchangeFilterFunction basicAuthenticationInternal( + Function> credentialsFunction) { return ExchangeFilterFunction.ofRequestProcessor( - clientRequest -> { - String authorization = authorization(usernameFunction.apply(clientRequest), - passwordFunction.apply(clientRequest)); - ClientRequest authorizedRequest = ClientRequest.from(clientRequest) - .headers(headers -> { - headers.set(HttpHeaders.AUTHORIZATION, authorization); - }) - .build(); - return Mono.just(authorizedRequest); - }); + clientRequest -> credentialsFunction.apply(clientRequest).map( + credentials -> { + ClientRequest authorizedRequest = ClientRequest.from(clientRequest) + .headers(headers -> { + headers.set(HttpHeaders.AUTHORIZATION, + authorization(credentials)); + }) + .build(); + return Mono.just(authorizedRequest); + }) + .orElse(Mono.just(clientRequest))); } - private static String authorization(String username, String password) { - String credentials = username + ":" + password; - byte[] credentialBytes = credentials.getBytes(StandardCharsets.ISO_8859_1); + private static String authorization(Credentials credentials) { + byte[] credentialBytes = credentials.toByteArray(StandardCharsets.ISO_8859_1); byte[] encodedBytes = Base64.getEncoder().encode(credentialBytes); String encodedCredentials = new String(encodedBytes, StandardCharsets.ISO_8859_1); return "Basic " + encodedCredentials; } + private static class Credentials { + + private String username; + + private String password; + + public Credentials(String username, String password) { + this.username = username; + this.password = password; + } + + public byte[] toByteArray(Charset charset) { + String credentials = this.username + ":" + this.password; + return credentials.getBytes(charset); + } + + } + } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java index d26b7c15c06..3aaecbb0a72 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java @@ -117,4 +117,20 @@ public class ExchangeFilterFunctionsTests { assertEquals(response, result); } + @Test + public void basicAuthenticationAbsentAttributes() throws Exception { + ClientRequest request = ClientRequest.method(GET, URI.create("http://example.com")).build(); + ClientResponse response = mock(ClientResponse.class); + + ExchangeFunction exchange = r -> { + assertFalse(r.headers().containsKey(HttpHeaders.AUTHORIZATION)); + return Mono.just(response); + }; + + ExchangeFilterFunction auth = ExchangeFilterFunctions.basicAuthentication(); + assertFalse(request.headers().containsKey(HttpHeaders.AUTHORIZATION)); + ClientResponse result = auth.filter(request, exchange).block(); + assertEquals(response, result); + } + }