diff --git a/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java b/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java index 7ff54876b..5e77903ee 100644 --- a/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java +++ b/src/main/java/org/springframework/data/repository/query/EvaluationContextExtensionInformation.java @@ -29,14 +29,17 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Optional; -import java.util.stream.Collectors; import org.springframework.beans.BeanUtils; import org.springframework.data.repository.query.EvaluationContextExtensionInformation.ExtensionTypeInformation.PublicMethodAndFieldFilter; +import org.springframework.data.repository.query.Functions.NameAndArgumentCount; import org.springframework.data.repository.query.spi.EvaluationContextExtension; import org.springframework.data.repository.query.spi.Function; +import org.springframework.data.util.MultiValueMapCollector; import org.springframework.data.util.Streamable; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; import org.springframework.util.ReflectionUtils; import org.springframework.util.ReflectionUtils.FieldFilter; import org.springframework.util.ReflectionUtils.MethodFilter; @@ -126,7 +129,7 @@ class EvaluationContextExtensionInformation { * * @return the functions will never be {@literal null}. */ - private final Map functions; + private final MultiValueMap functions; /** * Creates a new {@link ExtensionTypeInformation} fir the given type. @@ -141,15 +144,15 @@ class EvaluationContextExtensionInformation { this.properties = discoverDeclaredProperties(type); } - private static Map discoverDeclaredFunctions(Class type) { + private static MultiValueMap discoverDeclaredFunctions(Class type) { - Map map = new HashMap<>(); + MultiValueMap map = CollectionUtils.toMultiValueMap(new HashMap<>()); ReflectionUtils.doWithMethods(type, // - method -> map.put(method.getName(), new Function(method, null)), // + method -> map.add(NameAndArgumentCount.of(method), new Function(method, null)), // PublicMethodAndFieldFilter.STATIC); - return map.isEmpty() ? Collections.emptyMap() : Collections.unmodifiableMap(map); + return CollectionUtils.unmodifiableMultiValueMap(map); } @RequiredArgsConstructor @@ -235,8 +238,7 @@ class EvaluationContextExtensionInformation { }, PublicMethodAndFieldFilter.NON_STATIC); - ReflectionUtils.doWithFields(type, RootObjectInformation.this.fields::add, - PublicMethodAndFieldFilter.NON_STATIC); + ReflectionUtils.doWithFields(type, RootObjectInformation.this.fields::add, PublicMethodAndFieldFilter.NON_STATIC); } /** @@ -245,14 +247,15 @@ class EvaluationContextExtensionInformation { * @param target can be {@literal null}. * @return the methods */ - public Map getFunctions(Optional target) { - - return target.map(it -> methods.stream()// - .collect(Collectors.toMap(// - Method::getName, // - method -> new Function(method, it), // - (left, right) -> right))) - .orElseGet(Collections::emptyMap); + public MultiValueMap getFunctions(Optional target) { + + return target.map( // + it -> methods.stream().collect( // + new MultiValueMapCollector<>( // + m -> NameAndArgumentCount.of(m), // + m -> new Function(m, it) // + ))) // + .orElseGet(() -> CollectionUtils.toMultiValueMap(Collections.emptyMap())); } /** diff --git a/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java b/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java index cb4fb1076..2b7615004 100644 --- a/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java +++ b/src/main/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProvider.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.stream.Collectors; @@ -60,6 +59,7 @@ import org.springframework.util.StringUtils; * @author Thomas Darimont * @author Oliver Gierke * @author Christoph Strobl + * @author Jens Schauder * @since 1.9 */ public class ExtensionAwareEvaluationContextProvider implements EvaluationContextProvider, ApplicationContextAware { @@ -316,10 +316,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex */ private Optional getMethodExecutor(EvaluationContextExtensionAdapter adapter, String name, List argumentTypes) { - - return adapter.getFunctions().entrySet().stream()// - .filter(entry -> entry.getKey().equals(name))// - .findFirst().map(Entry::getValue).map(FunctionMethodExecutor::new); + return adapter.getFunctions().get(name, argumentTypes).map(FunctionMethodExecutor::new); } /** @@ -388,7 +385,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex private final EvaluationContextExtension extension; - private final Map functions; + private final Functions functions = new Functions(); private final Map properties; /** @@ -401,17 +398,16 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex public EvaluationContextExtensionAdapter(EvaluationContextExtension extension, EvaluationContextExtensionInformation information) { - Assert.notNull(extension, "Extenstion must not be null!"); + Assert.notNull(extension, "Extension must not be null!"); Assert.notNull(information, "Extension information must not be null!"); Optional target = Optional.ofNullable(extension.getRootObject()); ExtensionTypeInformation extensionTypeInformation = information.getExtensionTypeInformation(); RootObjectInformation rootObjectInformation = information.getRootObjectInformation(target); - this.functions = new HashMap<>(); - this.functions.putAll(extensionTypeInformation.getFunctions()); - this.functions.putAll(rootObjectInformation.getFunctions(target)); - this.functions.putAll(extension.getFunctions()); + functions.addAll(extension.getFunctions()); + functions.addAll(rootObjectInformation.getFunctions(target)); + functions.addAll(extensionTypeInformation.getFunctions()); this.properties = new HashMap<>(); this.properties.putAll(extensionTypeInformation.getProperties()); @@ -435,7 +431,7 @@ public class ExtensionAwareEvaluationContextProvider implements EvaluationContex * * @return */ - public Map getFunctions() { + Functions getFunctions() { return this.functions; } diff --git a/src/main/java/org/springframework/data/repository/query/Functions.java b/src/main/java/org/springframework/data/repository/query/Functions.java new file mode 100644 index 000000000..2c0244619 --- /dev/null +++ b/src/main/java/org/springframework/data/repository/query/Functions.java @@ -0,0 +1,133 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.query; + +import lombok.AllArgsConstructor; +import lombok.Value; + +import java.lang.reflect.Method; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.springframework.core.convert.TypeDescriptor; +import org.springframework.data.repository.query.spi.Function; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; + +/** + * {@link MultiValueMap} like datastructure to keep lists of + * {@link org.springframework.data.repository.query.spi.Function}s indexed by name and argument list length, where the + * value lists are actually unique with respect to the signature. + * + * @author Jens Schauder + * @since 2.0 + */ +class Functions { + + private final MultiValueMap functions = CollectionUtils + .toMultiValueMap(new HashMap<>()); + + void addAll(Map newFunctions) { + + newFunctions.forEach((n, f) -> { + NameAndArgumentCount k = new NameAndArgumentCount(n, f.getParameterCount()); + List currentElements = get(k); + if (!contains(currentElements, f)) { + functions.add(k, f); + } + }); + } + + void addAll(MultiValueMap newFunctions) { + + newFunctions.forEach((k, list) -> { + List currentElements = get(k); + list.stream() // + .filter(f -> !contains(currentElements, f)) // + .forEach(f -> functions.add(k, f)); + }); + } + + List get(NameAndArgumentCount key) { + return functions.getOrDefault(key, Collections.emptyList()); + } + + /** + * Gets the function that best matches the parameters given. The {@code name} must match, and the + * {@code argumentTypes} must be compatible with parameter list of the function. In order to resolve ambiguity it + * checks for a method with exactly matching parameter list. + * + * @param name the name of the method + * @param argumentTypes types of arguments that the method must be able to accept + * @return a {@code Function} if a unique on gets found. {@code Optional.empty} if none matches. Throws + * {@link IllegalStateException} if multiple functions match the parameters. + */ + Optional get(String name, List argumentTypes) { + + Stream candidates = get(new NameAndArgumentCount(name, argumentTypes.size())).stream() // + .filter(f -> f.supports(argumentTypes)); + return bestMatch(candidates.collect(Collectors.toList()), argumentTypes); + } + + private static boolean contains(List elements, Function f) { + return elements.stream().anyMatch(f::isSignatureEqual); + } + + private static Optional bestMatch(List candidates, List argumentTypes) { + + if (candidates.isEmpty()) { + return Optional.empty(); + } + if (candidates.size() == 1) { + return Optional.of(candidates.get(0)); + } + + Optional exactMatch = candidates.stream().filter(f -> f.supportsExact(argumentTypes)).findFirst(); + if (!exactMatch.isPresent()) { + throw new IllegalStateException(createErrorMessage(candidates, argumentTypes)); + } + + return exactMatch; + } + + private static String createErrorMessage(List candidates, List argumentTypes) { + + String argumentTypeString = String.join( // + ",", // + argumentTypes.stream().map(TypeDescriptor::getName).collect(Collectors.toList())); + + String messageTemplate = "There are multiple matching methods of name '%s' for parameter types (%s), but no " + + "exact match. Make sure to provide only one matching overload or one with exactly those types."; + + return String.format(messageTemplate, candidates.get(0).getName(), argumentTypeString); + } + + @Value + @AllArgsConstructor + static class NameAndArgumentCount { + String name; + int count; + + static NameAndArgumentCount of(Method m) { + return new NameAndArgumentCount(m.getName(), m.getParameterCount()); + } + } +} diff --git a/src/main/java/org/springframework/data/repository/query/spi/Function.java b/src/main/java/org/springframework/data/repository/query/spi/Function.java index 84f8ff676..97b3006a5 100644 --- a/src/main/java/org/springframework/data/repository/query/spi/Function.java +++ b/src/main/java/org/springframework/data/repository/query/spi/Function.java @@ -1,5 +1,5 @@ /* - * Copyright 2014 the original author or authors. + * Copyright 2014-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.data.repository.query.spi; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.util.Arrays; import java.util.List; import org.springframework.core.convert.TypeDescriptor; @@ -29,6 +30,7 @@ import org.springframework.util.TypeUtils; * * @author Thomas Darimont * @author Oliver Gierke + * @author Jens Schauder * @since 1.9 */ public class Function { @@ -115,4 +117,49 @@ public class Function { return true; } + + /** + * Returns the number of parameters required by the underlying method. + * + * @return + */ + public int getParameterCount() { + return method.getParameterCount(); + } + + /** + * Checks if the encapsulated method has exactly the argument types as those passed as an argument. + * + * @param argumentTypes a list of {@link TypeDescriptor}s to compare with the argument types of the method + * @return {@code true} if the types are equal, {@code false} otherwise. + */ + public boolean supportsExact(List argumentTypes) { + + if (method.getParameterCount() != argumentTypes.size()) { + return false; + } + + Class[] parameterTypes = method.getParameterTypes(); + + for (int i = 0; i < parameterTypes.length; i++) { + if (parameterTypes[i] != argumentTypes.get(i).getType()) { + return false; + } + } + + return true; + } + + /** + * Checks wether this {@code Function} has the same signature as another {@code Function}. + * + * @param other the {@code Function} to compare {@code this} with. + * + * @return {@code true} iff name and argument list are the same. + */ + public boolean isSignatureEqual(Function other) { + + return getName().equals(other.getName()) // + && Arrays.equals(method.getParameterTypes(), other.method.getParameterTypes()); + } } diff --git a/src/main/java/org/springframework/data/util/MultiValueMapCollector.java b/src/main/java/org/springframework/data/util/MultiValueMapCollector.java new file mode 100644 index 000000000..597d82911 --- /dev/null +++ b/src/main/java/org/springframework/data/util/MultiValueMapCollector.java @@ -0,0 +1,97 @@ +/* + * Copyright 2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.util; + +import lombok.NonNull; +import lombok.RequiredArgsConstructor; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collector; + +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; + +/** + * A {@link Collector} for building a {@link MultiValueMap} from a {@link java.util.stream.Stream}. + * + * @author Jens Schauder + * @since 2.0 + */ +@RequiredArgsConstructor +public class MultiValueMapCollector implements Collector, MultiValueMap> { + + @NonNull private final Function keyFunction; + @NonNull private final Function valueFunction; + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#supplier() + */ + @Override + public Supplier> supplier() { + return () -> CollectionUtils.toMultiValueMap(new HashMap<>()); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#accumulator() + */ + @Override + public BiConsumer, T> accumulator() { + return (map, t) -> map.add(keyFunction.apply(t), valueFunction.apply(t)); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#combiner() + */ + @Override + public BinaryOperator> combiner() { + + return (map1, map2) -> { + + for (K key : map2.keySet()) { + map1.addAll(key, map2.get(key)); + } + + return map1; + }; + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#finisher() + */ + @Override + public Function, MultiValueMap> finisher() { + return Function.identity(); + } + + /* + * (non-Javadoc) + * @see java.util.stream.Collector#characteristics() + */ + @Override + public Set characteristics() { + return EnumSet.of(Characteristics.IDENTITY_FINISH, Characteristics.UNORDERED); + } +} diff --git a/src/main/java/org/springframework/data/util/StreamUtils.java b/src/main/java/org/springframework/data/util/StreamUtils.java index ea84b2554..45eaa8499 100644 --- a/src/main/java/org/springframework/data/util/StreamUtils.java +++ b/src/main/java/org/springframework/data/util/StreamUtils.java @@ -23,11 +23,13 @@ import java.util.List; import java.util.Set; import java.util.Spliterator; import java.util.Spliterators; +import java.util.function.Function; import java.util.stream.Collector; import java.util.stream.Stream; import java.util.stream.StreamSupport; import org.springframework.util.Assert; +import org.springframework.util.MultiValueMap; /** * Spring Data specific Java {@link Stream} utility methods and classes. @@ -77,4 +79,15 @@ public interface StreamUtils { public static Collector> toUnmodifiableSet() { return collectingAndThen(toSet(), Collections::unmodifiableSet); } + + /** + * Returns a {@link Collector} to create a {@link MultiValueMap}. + * + * @param keyFunction {@link Function} to create a key from an element of the {@link java.util.stream.Stream} + * @param valueFunction {@link Function} to create a value from an element of the {@link java.util.stream.Stream} + */ + public static Collector, MultiValueMap> toMultiMap(Function keyFunction, + Function valueFunction) { + return new MultiValueMapCollector(keyFunction, valueFunction); + } } diff --git a/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java b/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java index ff809823f..6e45f8ce7 100755 --- a/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java +++ b/src/test/java/org/springframework/data/mapping/model/ConvertingPropertyAccessorUnitTests.java @@ -59,7 +59,8 @@ public class ConvertingPropertyAccessorUnitTests { Entity entity = new Entity(); entity.id = 1L; - assertThat(getIdProperty()).hasValueSatisfying(it -> assertThat(getAccessor(entity, CONVERSION_SERVICE).getProperty(it, String.class)).hasValue("1")); + assertThat(getIdProperty()).hasValueSatisfying( + it -> assertThat(getAccessor(entity, CONVERSION_SERVICE).getProperty(it, String.class)).hasValue("1")); } @Test // DATACMNS-596 diff --git a/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java b/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java index 53bc88e6c..e508222ba 100755 --- a/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/query/ExtensionAwareEvaluationContextProviderUnitTests.java @@ -19,6 +19,7 @@ import static org.assertj.core.api.Assertions.*; import lombok.RequiredArgsConstructor; +import java.io.Serializable; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; @@ -28,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; import org.junit.Before; import org.junit.Test; import org.springframework.data.domain.PageRequest; @@ -44,7 +46,8 @@ import org.springframework.expression.spel.standard.SpelExpressionParser; * Unit tests {@link ExtensionAwareEvaluationContextProvider}. * * @author Oliver Gierke - * @author Thomas Darimont. + * @author Thomas Darimont + * @author Jens Schauder */ public class ExtensionAwareEvaluationContextProviderUnitTests { @@ -219,6 +222,68 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { assertThat(counter.get()).isEqualTo(2); } + @Test // DATACMNS-1026 + public void overloadedMethodsGetResolved() throws Exception { + + provider = createContextProviderWithOverloads(); + + // from the root object + assertThat(evaluateExpression("method()")).isEqualTo("zero"); + assertThat(evaluateExpression("method(23)")).isEqualTo("single-int"); + assertThat(evaluateExpression("method('hello')")).isEqualTo("single-string"); + assertThat(evaluateExpression("method('one', 'two')")).isEqualTo("two"); + + // from the extension + assertThat(evaluateExpression("method(1, 2)")).isEqualTo("two-ints"); + assertThat(evaluateExpression("method(1, 'two')")).isEqualTo("int-and-string"); + } + + @Test // DATACMNS-1026 + public void methodFromRootObjectOverwritesMethodFromExtension() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("ambiguous()")).isEqualTo("from-root"); + } + + @Test // DATACMNS-1026 + public void aliasedMethodOverwritesMethodFromRootObject() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("aliasedMethod()")).isEqualTo("methodResult"); + } + + @Test // DATACMNS-1026 + public void exactMatchIsPreferred() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThat(evaluateExpression("ambiguousOverloaded('aString')")).isEqualTo("string"); + } + + @Test // DATACMNS-1026 + public void throwsExceptionWhenStillAmbiguous() throws Exception { + + provider = createContextProviderWithOverloads(); + + assertThatExceptionOfType(IllegalStateException.class) // + .isThrownBy(() -> evaluateExpression("ambiguousOverloaded(23)")) // + .withMessageContaining("ambiguousOverloaded") // + .withMessageContaining("(java.lang.Integer)"); + } + + private ExtensionAwareEvaluationContextProvider createContextProviderWithOverloads() { + + return new ExtensionAwareEvaluationContextProvider(Collections.singletonList( // + new DummyExtension("_first", "first") { + @Override + public Object getRootObject() { + return new RootWithOverloads(); + } + })); + } + @RequiredArgsConstructor public static class DummyExtension extends EvaluationContextExtensionSupport { @@ -269,6 +334,22 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { public static String extensionMethod() { return "methodResult"; } + + public static String method(int i1, int i2) { + return "two-ints"; + } + + public static String method(int i, String s) { + return "int-and-string"; + } + + public static String ambiguous() { + return "from-extension-type"; + } + + public static String ambiguousToo() { + return "from-extension-type"; + } } private Object evaluateExpression(String expression) { @@ -312,4 +393,43 @@ public class ExtensionAwareEvaluationContextProviderUnitTests { return "rootObjectInstanceMethod2"; } } + + public static class RootWithOverloads { + + public String method() { + return "zero"; + } + + public String method(String s) { + return "single-string"; + } + + public String method(int i) { + return "single-int"; + } + + public String method(String s1, String s2) { + return "two"; + } + + public String ambiguous() { + return "from-root"; + } + + public String aliasedMethod() { + return "from-root"; + } + + public String ambiguousOverloaded(String s) { + return "string"; + } + + public String ambiguousOverloaded(Object o) { + return "object"; + } + + public String ambiguousOverloaded(Serializable o) { + return "serializable"; + } + } }