From 42621086fecc597d88b32016e0da07141332fd92 Mon Sep 17 00:00:00 2001 From: Jens Schauder Date: Thu, 4 May 2017 11:23:04 +0200 Subject: [PATCH] DATACMNS-1026 - ExtensionAwareEvaluationContextProvider now returns all overloaded methods as functions. All overloaded methods are now available in SPeL expressions. Among methods with identical argument list from different sources in the same extension (extension, root object, aliases) the last one in the order in parens wins. If there is more than one method for an application the following rules are applied: if there is one method with exact matching types in the argument list it is used, otherwise an exception is thrown. Original pull request: #217. --- ...EvaluationContextExtensionInformation.java | 35 ++--- ...tensionAwareEvaluationContextProvider.java | 20 ++- .../data/repository/query/Functions.java | 133 ++++++++++++++++++ .../data/repository/query/spi/Function.java | 49 ++++++- .../data/util/MultiValueMapCollector.java | 97 +++++++++++++ .../data/util/StreamUtils.java | 13 ++ .../ConvertingPropertyAccessorUnitTests.java | 3 +- ...areEvaluationContextProviderUnitTests.java | 122 +++++++++++++++- 8 files changed, 441 insertions(+), 31 deletions(-) create mode 100644 src/main/java/org/springframework/data/repository/query/Functions.java create mode 100644 src/main/java/org/springframework/data/util/MultiValueMapCollector.java diff --git a/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java b/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java index 7ff54876b..5e77903ee 100644 --- a/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java +++ b/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java @@ -29,14 +29,17 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; import org.springframework.beans.BeanUtils; import org.springframework.data.repository.query.EvaluationContextExtensionInformation.ExtensionTypeInformation.PublicMethodAndFieldFilter; +import org.springframework.data.repository.query.Functions.NameAndArgumentCount; import org.springframework.data.repository.query.spi.EvaluationContextExtension; import org.springframework.data.repository.query.spi.Function; +import org.springframework.data.util.MultiValueMapCollector; import org.springframework.data.util.Streamable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils.FieldFilter; import org.springframework.util.ReflectionUtils.MethodFilter; @@ -126,7 +129,7 @@ class EvaluationContextExtensionInformation { * * @return the functions will never be {@literal null}. */ - private final Map functions; + private final MultiValueMap functions; /** * Creates a new {@link ExtensionTypeInformation} fir the given type. @@ -141,15 +144,15 @@ class EvaluationContextExtensionInformation { this.properties = discoverDeclaredProperties(type); } - private static Map discoverDeclaredFunctions(Class type) { + private static MultiValueMap discoverDeclaredFunctions(Class type) { - Map map = new HashMap<>(); + MultiValueMap map = CollectionUtils.toMultiValueMap(new HashMap<>()); ReflectionUtils.doWithMethods(type, // - method -> map.put(method.getName(), new Function(method, null)), // + method -> map.add(NameAndArgumentCount.of(method), new Function(method, null)), // PublicMethodAndFieldFilter.STATIC); - return map.isEmpty() ? Collections.emptyMap() : Collections.unmodifiableMap(map); + return CollectionUtils.unmodifiableMultiValueMap(map); } @RequiredArgsConstructor @@ -235,8 +238,7 @@ class EvaluationContextExtensionInformation { }, PublicMethodAndFieldFilter.NON_STATIC); - ReflectionUtils.doWithFields(type, RootObjectInformation.this.fields::add, - PublicMethodAndFieldFilter.NON_STATIC); + ReflectionUtils.doWithFields(type, RootObjectInformation.this.fields::add, PublicMethodAndFieldFilter.NON_STATIC); } /** @@ -245,14 +247,15 @@ class EvaluationContextExtensionInformation { * @param target can be {@literal null}. * @return the methods */ - public Map getFunctions(Optional target) { - - return target.map(it -> methods.stream()// - .collect(Collectors.toMap(// - Method::getName, // - method -> new Function(method, it), // - (left, right) -> right))) - .orElseGet(Collections::emptyMap); + public MultiValueMap getFunctions(Optional target) { + + return target.map( // + it -> methods.stream().collect( // + new MultiValueMapCollector<>( // + m -> NameAndArgumentCount.of(m), // + m -> new Function(m, it) // + ))) // + .orElseGet(() -> CollectionUtils.toMultiValueMap(Collections.emptyMap())); } /** diff --git a/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java b/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java index cb4fb1076..2b7615004 100644 --- a/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java +++ b/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.stream.Collectors; @@ -60,6 +59,7 @@ import org.springframework.util.StringUtils; * @author Thomas Darimont * @author Oliver Gierke * @author Christoph Strobl + * @author Jens Schauder * @since 1.9 */ public class ExtensionAwareEvaluationContextProvider implements EvaluationContextProvider, ApplicationContextAware { @@ -316,10 +316,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex */ private Optional getMethodExecutor(EvaluationContextExtensionAdapter adapter, String name, List argumentTypes) { - - return adapter.getFunctions().entrySet().stream()// - .filter(entry -> entry.getKey().equals(name))// - .findFirst().map(Entry::getValue).map(FunctionMethodExecutor::new); + return adapter.getFunctions().get(name, argumentTypes).map(FunctionMethodExecutor::new); } /** @@ -388,7 +385,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex private final EvaluationContextExtension extension; - private final Map functions; + private final Functions functions = new Functions(); private final Map properties; /** @@ -401,17 +398,16 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex public EvaluationContextExtensionAdapter(EvaluationContextExtension extension, EvaluationContextExtensionInformation information) { - Assert.notNull(extension, "Extenstion must not be null!"); + Assert.notNull(extension, "Extension must not be null!"); Assert.notNull(information, "Extension information must not be null!"); Optional target = Optional.ofNullable(extension.getRootObject()); ExtensionTypeInformation extensionTypeInformation = information.getExtensionTypeInformation(); RootObjectInformation rootObjectInformation = information.getRootObjectInformation(target); - this.functions = new HashMap<>(); - this.functions.putAll(extensionTypeInformation.getFunctions()); - this.functions.putAll(rootObjectInformation.getFunctions(target)); - this.functions.putAll(extension.getFunctions()); + functions.addAll(extension.getFunctions()); + functions.addAll(rootObjectInformation.getFunctions(target)); + functions.addAll(extensionTypeInformation.getFunctions()); this.properties = new HashMap<>(); this.properties.putAll(extensionTypeInformation.getProperties()); @@ -435,7 +431,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex * * @return */ - public Map getFunctions() { + Functions getFunctions() { return this.functions; } diff --git a/src/main/java/org/springframework/data/repository/query/Functions.java b/src/main/java/org/springframework/data/repository/query/Functions.java new file mode 100644 index 000000000..2c0244619 --- /dev/null +++ b/src/main/java/org/springframework/data/repository/query/Functions.java @@ -0,0 +1,133 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.query; + +import lombok.AllArgsConstructor; +import lombok.Value; + +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.data.repository.query.spi.Function; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; + +/** + * {@link MultiValueMap} like datastructure to keep lists of + * {@link org.springframework.data.repository.query.spi.Function}s indexed by name and argument list length, where the + * value lists are actually unique with respect to the signature. + * + * @author Jens Schauder + * @since 2.0 + */ +class Functions { + + private final MultiValueMap functions = CollectionUtils + .toMultiValueMap(new HashMap<>()); + + void addAll(Map newFunctions) { + + newFunctions.forEach((n, f) -> { + NameAndArgumentCount k = new NameAndArgumentCount(n, f.getParameterCount()); + List currentElements = get(k); + if (!contains(currentElements, f)) { + functions.add(k, f); + } + }); + } + + void addAll(MultiValueMap newFunctions) { + + newFunctions.forEach((k, list) -> { + List currentElements = get(k); + list.stream() // + .filter(f -> !contains(currentElements, f)) // + .forEach(f -> functions.add(k, f)); + }); + } + + List get(NameAndArgumentCount key) { + return functions.getOrDefault(key, Collections.emptyList()); + } + + /** + * Gets the function that best matches the parameters given. The {@code name} must match, and the + * {@code argumentTypes} must be compatible with parameter list of the function. In order to resolve ambiguity it + * checks for a method with exactly matching parameter list. + * + * @param name the name of the method + * @param argumentTypes types of arguments that the method must be able to accept + * @return a {@code Function} if a unique on gets found. {@code Optional.empty} if none matches. Throws + * {@link IllegalStateException} if multiple functions match the parameters. + */ + Optional get(String name, List argumentTypes) { + + Stream candidates = get(new NameAndArgumentCount(name, argumentTypes.size())).stream() // + .filter(f -> f.supports(argumentTypes)); + return bestMatch(candidates.collect(Collectors.toList()), argumentTypes); + } + + private static boolean contains(List elements, Function f) { + return elements.stream().anyMatch(f::isSignatureEqual); + } + + private static Optional bestMatch(List candidates, List argumentTypes) { + + if (candidates.isEmpty()) { + return Optional.empty(); + } + if (candidates.size() == 1) { + return Optional.of(candidates.get(0)); + } + + Optional exactMatch = candidates.stream().filter(f -> f.supportsExact(argumentTypes)).findFirst(); + if (!exactMatch.isPresent()) { + throw new IllegalStateException(createErrorMessage(candidates, argumentTypes)); + } + + return exactMatch; + } + + private static String createErrorMessage(List candidates, List argumentTypes) { + + String argumentTypeString = String.join( // + ",", // + argumentTypes.stream().map(TypeDescriptor::getName).collect(Collectors.toList())); + + String messageTemplate = "There are multiple matching methods of name '%s' for parameter types (%s), but no " + + "exact match. Make sure to provide only one matching overload or one with exactly those types."; + + return String.format(messageTemplate, candidates.get(0).getName(), argumentTypeString); + } + + @Value + @AllArgsConstructor + static class NameAndArgumentCount { + String name; + int count; + + static NameAndArgumentCount of(Method m) { + return new NameAndArgumentCount(m.getName(), m.getParameterCount()); + } + } +} diff --git a/src/main/java/org/springframework/data/repository/query/spi/Function.java b/src/main/java/org/springframework/data/repository/query/spi/Function.java index 84f8ff676..97b3006a5 100644 --- a/src/main/java/org/springframework/data/repository/query/spi/Function.java +++ b/src/main/java/org/springframework/data/repository/query/spi/Function.java @@ -1,5 +1,5 @@ /* - * Copyright 2014 the original author or authors. + * Copyright 2014-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.data.repository.query.spi; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.Arrays; import java.util.List; import org.springframework.core.convert.TypeDescriptor; @@ -29,6 +30,7 @@ import org.springframework.util.TypeUtils; * * @author Thomas Darimont * @author Oliver Gierke + * @author Jens Schauder * @since 1.9 */ public class Function { @@ -115,4 +117,49 @@ public class Function { return true; } + + /** + * Returns the number of parameters required by the underlying method. + * + * @return + */ + public int getParameterCount() { + return method.getParameterCount(); + } + + /** + * Checks if the encapsulated method has exactly the argument types as those passed as an argument. + * + * @param argumentTypes a list of {@link TypeDescriptor}s to compare with the argument types of the method + * @return {@code true} if the types are equal, {@code false} otherwise. + */ + public boolean supportsExact(List argumentTypes) { + + if (method.getParameterCount() != argumentTypes.size()) { + return false; + } + + Class[] parameterTypes = method.getParameterTypes(); + + for (int i = 0; i < parameterTypes.length; i++) { + if (parameterTypes[i] != argumentTypes.get(i).getType()) { + return false; + } + } + + return true; + } + + /** + * Checks wether this {@code Function} has the same signature as another {@code Function}. + * + * @param other the {@code Function} to compare {@code this} with. + * + * @return {@code true} iff name and argument list are the same. + */ + public boolean isSignatureEqual(Function other) { + + return getName().equals(other.getName()) // + && Arrays.equals(method.getParameterTypes(), other.method.getParameterTypes()); + } } diff --git a/src/main/java/org/springframework/data/util/MultiValueMapCollector.java b/src/main/java/org/springframework/data/util/MultiValueMapCollector.java new file mode 100644 index 000000000..597d82911 --- /dev/null +++ b/src/main/java/org/springframework/data/util/MultiValueMapCollector.java @@ -0,0 +1,97 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.util; + +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collector; + +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; + +/** + * A {@link Collector} for building a {@link MultiValueMap} from a {@link java.util.stream.Stream}. + * + * @author Jens Schauder + * @since 2.0 + */ +@RequiredArgsConstructor +public class MultiValueMapCollector implements Collector, MultiValueMap> { + + @NonNull private final Function keyFunction; + @NonNull private final Function valueFunction; + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#supplier() + */ + @Override + public Supplier> supplier() { + return () -> CollectionUtils.toMultiValueMap(new HashMap<>()); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#accumulator() + */ + @Override + public BiConsumer, T> accumulator() { + return (map, t) -> map.add(keyFunction.apply(t), valueFunction.apply(t)); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#combiner() + */ + @Override + public BinaryOperator> combiner() { + + return (map1, map2) -> { + + for (K key : map2.keySet()) { + map1.addAll(key, map2.get(key)); + } + + return map1; + }; + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#finisher() + */ + @Override + public Function, MultiValueMap> finisher() { + return Function.identity(); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#characteristics() + */ + @Override + public Set characteristics() { + return EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED); + } +} diff --git a/src/main/java/org/springframework/data/util/StreamUtils.java b/src/main/java/org/springframework/data/util/StreamUtils.java index ea84b2554..45eaa8499 100644 --- a/src/main/java/org/springframework/data/util/StreamUtils.java +++ b/src/main/java/org/springframework/data/util/StreamUtils.java @@ -23,11 +23,13 @@ import java.util.List; import java.util.Set; import java.util.Spliterator; import java.util.Spliterators; +import java.util.function.Function; import java.util.stream.Collector; import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; /** * Spring Data specific Java {@link Stream} utility methods and classes. @@ -77,4 +79,15 @@ public interface StreamUtils { public static Collector> toUnmodifiableSet() { return collectingAndThen(toSet(), Collections::unmodifiableSet); } + + /** + * Returns a {@link Collector} to create a {@link MultiValueMap}. + * + * @param keyFunction {@link Function} to create a key from an element of the {@link java.util.stream.Stream} + * @param valueFunction {@link Function} to create a value from an element of the {@link java.util.stream.Stream} + */ + public static Collector, MultiValueMap> toMultiMap(Function keyFunction, + Function valueFunction) { + return new MultiValueMapCollector(keyFunction, valueFunction); + } } diff --git a/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java b/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java index ff809823f..6e45f8ce7 100755 --- a/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java +++ b/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java @@ -59,7 +59,8 @@ public class ConvertingPropertyAccessorUnitTests { Entity entity = new Entity(); entity.id = 1L; - assertThat(getIdProperty()).hasValueSatisfying(it -> assertThat(getAccessor(entity, CONVERSION_SERVICE).getProperty(it, String.class)).hasValue("1")); + assertThat(getIdProperty()).hasValueSatisfying( + it -> assertThat(getAccessor(entity, CONVERSION_SERVICE).getProperty(it, String.class)).hasValue("1")); } @Test // DATACMNS-596 diff --git a/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java b/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java index 53bc88e6c..e508222ba 100755 --- a/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.*; import lombok.RequiredArgsConstructor; +import java.io.Serializable; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; @@ -28,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Test; import org.springframework.data.domain.PageRequest; @@ -44,7 +46,8 @@ import org.springframework.expression.spel.standard.SpelExpressionParser; * Unit tests {@link ExtensionAwareEvaluationContextProvider}. * * @author Oliver Gierke - * @author Thomas Darimont. + * @author Thomas Darimont + * @author Jens Schauder */ public class ExtensionAwareEvaluationContextProviderUnitTests { @@ -219,6 +222,68 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { assertThat(counter.get()).isEqualTo(2); } + @Test // DATACMNS-1026 + public void overloadedMethodsGetResolved() throws Exception { + + provider = createContextProviderWithOverloads(); + + // from the root object + assertThat(evaluateExpression("method()")).isEqualTo("zero"); + assertThat(evaluateExpression("method(23)")).isEqualTo("single-int"); + assertThat(evaluateExpression("method('hello')")).isEqualTo("single-string"); + assertThat(evaluateExpression("method('one', 'two')")).isEqualTo("two"); + + // from the extension + assertThat(evaluateExpression("method(1, 2)")).isEqualTo("two-ints"); + assertThat(evaluateExpression("method(1, 'two')")).isEqualTo("int-and-string"); + } + + @Test // DATACMNS-1026 + public void methodFromRootObjectOverwritesMethodFromExtension() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("ambiguous()")).isEqualTo("from-root"); + } + + @Test // DATACMNS-1026 + public void aliasedMethodOverwritesMethodFromRootObject() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("aliasedMethod()")).isEqualTo("methodResult"); + } + + @Test // DATACMNS-1026 + public void exactMatchIsPreferred() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("ambiguousOverloaded('aString')")).isEqualTo("string"); + } + + @Test // DATACMNS-1026 + public void throwsExceptionWhenStillAmbiguous() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThatExceptionOfType(IllegalStateException.class) // + .isThrownBy(() -> evaluateExpression("ambiguousOverloaded(23)")) // + .withMessageContaining("ambiguousOverloaded") // + .withMessageContaining("(java.lang.Integer)"); + } + + private ExtensionAwareEvaluationContextProvider createContextProviderWithOverloads() { + + return new ExtensionAwareEvaluationContextProvider(Collections.singletonList( // + new DummyExtension("_first", "first") { + @Override + public Object getRootObject() { + return new RootWithOverloads(); + } + })); + } + @RequiredArgsConstructor public static class DummyExtension extends EvaluationContextExtensionSupport { @@ -269,6 +334,22 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { public static String extensionMethod() { return "methodResult"; } + + public static String method(int i1, int i2) { + return "two-ints"; + } + + public static String method(int i, String s) { + return "int-and-string"; + } + + public static String ambiguous() { + return "from-extension-type"; + } + + public static String ambiguousToo() { + return "from-extension-type"; + } } private Object evaluateExpression(String expression) { @@ -312,4 +393,43 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { return "rootObjectInstanceMethod2"; } } + + public static class RootWithOverloads { + + public String method() { + return "zero"; + } + + public String method(String s) { + return "single-string"; + } + + public String method(int i) { + return "single-int"; + } + + public String method(String s1, String s2) { + return "two"; + } + + public String ambiguous() { + return "from-root"; + } + + public String aliasedMethod() { + return "from-root"; + } + + public String ambiguousOverloaded(String s) { + return "string"; + } + + public String ambiguousOverloaded(Object o) { + return "object"; + } + + public String ambiguousOverloaded(Serializable o) { + return "serializable"; + } + } }