diff --git a/src/main/java/org/springframework/data/repository/util/ClassUtils.java b/src/main/java/org/springframework/data/repository/util/ClassUtils.java index ca5952eb5..cd9915716 100644 --- a/src/main/java/org/springframework/data/repository/util/ClassUtils.java +++ b/src/main/java/org/springframework/data/repository/util/ClassUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2008-2013 the original author or authors. + * Copyright 2008-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ import java.util.Arrays; import java.util.Collection; import org.springframework.data.repository.Repository; +import org.springframework.data.util.ClassTypeInformation; +import org.springframework.data.util.TypeInformation; +import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; @@ -34,9 +37,7 @@ public abstract class ClassUtils { /** * Private constructor to prevent instantiation. */ - private ClassUtils() { - - } + private ClassUtils() {} /** * Returns whether the given class contains a property with the given name. @@ -96,15 +97,22 @@ public abstract class ClassUtils { } /** - * Asserts the given {@link Method}'s return type to be one of the given types. + * Asserts the given {@link Method}'s return type to be one of the given types. Will unwrap known wrapper types before + * the assignment check (see {@link QueryExecutionConverters}). * - * @param method - * @param types + * @param method must not be {@literal null}. + * @param types must not be {@literal null} or empty. */ public static void assertReturnTypeAssignable(Method method, Class... types) { + Assert.notNull(method, "Method must not be null!"); + Assert.notEmpty(types, "Types must not be null or empty!"); + + TypeInformation returnType = ClassTypeInformation.fromReturnTypeOf(method); + returnType = QueryExecutionConverters.supports(returnType.getType()) ? returnType.getComponentType() : returnType; + for (Class type : types) { - if (type.isAssignableFrom(method.getReturnType())) { + if (type.isAssignableFrom(returnType.getType())) { return; } } diff --git a/src/test/java/org/springframework/data/repository/util/ClassUtilsUnitTests.java b/src/test/java/org/springframework/data/repository/util/ClassUtilsUnitTests.java index c8266b11d..a6998a6dd 100644 --- a/src/test/java/org/springframework/data/repository/util/ClassUtilsUnitTests.java +++ b/src/test/java/org/springframework/data/repository/util/ClassUtilsUnitTests.java @@ -19,13 +19,16 @@ import static org.junit.Assert.*; import static org.springframework.data.repository.util.ClassUtils.*; import java.io.Serializable; +import java.lang.reflect.Method; import java.util.List; import java.util.Map; +import java.util.concurrent.Future; import org.junit.Test; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.repository.Repository; +import org.springframework.scheduling.annotation.Async; /** * Unit test for {@link ClassUtils}. @@ -48,6 +51,17 @@ public class ClassUtilsUnitTests { assertFalse(hasProperty(User.class, "address")); } + /** + * @see DATACMNS-769 + */ + @Test + public void unwrapsWrapperTypesBeforeAssignmentCheck() throws Exception { + + Method method = UserRepository.class.getMethod("findAsync", Pageable.class); + + assertReturnTypeAssignable(method, Page.class); + } + @SuppressWarnings("unused") private class User { @@ -61,6 +75,8 @@ public class ClassUtilsUnitTests { static interface UserRepository extends Repository { + @Async + Future> findAsync(Pageable pageable); } interface SomeDao extends Serializable, UserRepository {