From 8032baa29693f47898ded4e96272e13847588ee4 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Sun, 29 Oct 2017 17:47:51 -0500 Subject: [PATCH] Polish InMemoryClientRegistrationRepository - use Map.get - Construct with stream() - Add tests - Remove unnecessary unmodifiableCollection (already unmodifiable) Fixes gh-4745 --- .../registration/ClientRegistration.java | 15 +++ .../InMemoryClientRegistrationRepository.java | 28 +++--- ...moryClientRegistrationRepositoryTests.java | 94 +++++++++++++++++++ 3 files changed, 122 insertions(+), 15 deletions(-) create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java index a7f5aa45db..4db3d2152a 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java @@ -82,6 +82,21 @@ public final class ClientRegistration { return this.clientName; } + @Override + public String toString() { + return "ClientRegistration{" + + "registrationId='" + this.registrationId + '\'' + + ", clientId='" + this.clientId + '\'' + + ", clientSecret='" + this.clientSecret + '\'' + + ", clientAuthenticationMethod=" + this.clientAuthenticationMethod + + ", authorizationGrantType=" + this.authorizationGrantType + + ", redirectUri='" + this.redirectUri + '\'' + + ", scopes=" + this.scopes + + ", providerDetails=" + this.providerDetails + + ", clientName='" + this.clientName + + '\'' + '}'; + } + public class ProviderDetails { private String authorizationUri; private String tokenUri; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java index fa2ddfb00b..8cea2b91b5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java @@ -21,7 +21,13 @@ import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Function; +import java.util.stream.Collector; + +import static java.util.stream.Collectors.collectingAndThen; +import static java.util.stream.Collectors.toConcurrentMap; +import static java.util.stream.Collectors.toMap; /** * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory. @@ -36,28 +42,20 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr public InMemoryClientRegistrationRepository(List registrations) { Assert.notEmpty(registrations, "registrations cannot be empty"); - Map registrationsMap = new ConcurrentHashMap<>(); - registrations.forEach(registration -> { - if (registrationsMap.containsKey(registration.getRegistrationId())) { - throw new IllegalArgumentException("ClientRegistration must be unique. Found duplicate registrationId: " + - registration.getRegistrationId()); - } - registrationsMap.put(registration.getRegistrationId(), registration); - }); - this.registrations = Collections.unmodifiableMap(registrationsMap); + Collector> collector = + toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity()); + this.registrations = registrations.stream() + .collect(collectingAndThen(collector, Collections::unmodifiableMap)); } @Override public ClientRegistration findByRegistrationId(String registrationId) { Assert.hasText(registrationId, "registrationId cannot be empty"); - return this.registrations.values().stream() - .filter(registration -> registration.getRegistrationId().equals(registrationId)) - .findFirst() - .orElse(null); + return this.registrations.get(registrationId); } @Override public Iterator iterator() { - return Collections.unmodifiableCollection(this.registrations.values()).iterator(); + return this.registrations.values().iterator(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java new file mode 100644 index 0000000000..9df4fc5edd --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java @@ -0,0 +1,94 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.registration; + +import org.junit.Test; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + * @since 5.0 + */ +public class InMemoryClientRegistrationRepositoryTests { + private ClientRegistration registration = ClientRegistration.withRegistrationId("id") + .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) + .authorizationUri("https://example.com/oauth2/authorize") + .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) + .clientId("client-id") + .clientName("client-name") + .clientSecret("client-secret") + .redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}") + .scope("user") + .tokenUri("https://example.com/oauth/access_token") + .build(); + + private InMemoryClientRegistrationRepository clients = new InMemoryClientRegistrationRepository( + Arrays.asList(this.registration)); + + @Test(expected = IllegalArgumentException.class) + public void constructorListClientRegistrationWhenNullThenIllegalArgumentException() { + List registrations = null; + new InMemoryClientRegistrationRepository(registrations); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorListClientRegistrationWhenEmptyThenIllegalArgumentException() { + List registrations = Collections.emptyList(); + new InMemoryClientRegistrationRepository(registrations); + } + + @Test(expected = IllegalStateException.class) + public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() { + List registrations = Arrays.asList(this.registration, this.registration); + new InMemoryClientRegistrationRepository(registrations); + } + + @Test + public void findByRegistrationIdWhenFoundThenFound() { + String id = this.registration.getRegistrationId(); + assertThat(this.clients.findByRegistrationId(id)).isEqualTo(this.registration); + } + + @Test + public void findByRegistrationIdWhenNotFoundThenNull() { + String id = this.registration.getRegistrationId() + "MISSING"; + assertThat(this.clients.findByRegistrationId(id)).isNull(); + } + + @Test(expected = IllegalArgumentException.class) + public void findByRegistrationIdWhenNullIdThenIllegalArgumentException() { + String id = null; + assertThat(this.clients.findByRegistrationId(id)); + } + + @Test(expected = UnsupportedOperationException.class) + public void iteratorWhenRemoveThenThrowsUnsupportedOperationException() { + this.clients.iterator().remove(); + } + + @Test + public void iteratorWhenGetThenContainsAll() { + assertThat(this.clients.iterator()).containsOnly(this.registration); + } +}