diff --git a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java index 6b88777a5..740e73291 100644 --- a/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java +++ b/src/main/java/org/springframework/data/repository/core/support/AbstractRepositoryMetadata.java @@ -18,8 +18,10 @@ package org.springframework.data.repository.core.support; import java.lang.reflect.Method; import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Set; +import org.springframework.core.KotlinDetector; import org.springframework.data.domain.Pageable; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.CrudMethods; @@ -27,6 +29,7 @@ import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.util.QueryExecutionConverters; import org.springframework.data.repository.util.ReactiveWrappers; import org.springframework.data.util.ClassTypeInformation; +import org.springframework.data.util.KotlinReflectionUtils; import org.springframework.data.util.Lazy; import org.springframework.data.util.TypeInformation; import org.springframework.util.Assert; @@ -79,7 +82,20 @@ public abstract class AbstractRepositoryMetadata implements RepositoryMetadata { * @see org.springframework.data.repository.core.RepositoryMetadata#getReturnedDomainClass(java.lang.reflect.Method) */ public Class getReturnedDomainClass(Method method) { - return QueryExecutionConverters.unwrapWrapperTypes(typeInformation.getReturnType(method)).getType(); + + TypeInformation returnType = null; + if (KotlinDetector.isKotlinType(method.getDeclaringClass()) && KotlinReflectionUtils.isSuspend(method)) { + + // the last parameter is Continuation or Continuation> + List> types = typeInformation.getParameterTypes(method); + returnType = types.get(types.size() - 1).getComponentType(); + } + + if (returnType == null) { + returnType = typeInformation.getReturnType(method); + } + + return QueryExecutionConverters.unwrapWrapperTypes(returnType).getType(); } /* diff --git a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineRepositoryMetadataUnitTests.kt b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineRepositoryMetadataUnitTests.kt new file mode 100644 index 000000000..adaa77a8e --- /dev/null +++ b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineRepositoryMetadataUnitTests.kt @@ -0,0 +1,79 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.kotlin + +import kotlinx.coroutines.flow.Flow +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import org.springframework.data.repository.core.support.DefaultRepositoryMetadata +import org.springframework.data.repository.sample.User +import org.springframework.data.util.ReflectionUtils +import kotlin.coroutines.Continuation + +/** + * Unit tests for [org.springframework.data.repository.core.RepositoryMetadata]. + * + * @author Mark Paluch + */ +class CoroutineRepositoryMetadataUnitTests { + + @Test // DATACMNS-1689 + fun shouldDetermineCorrectResultType() { + + val metadata = DefaultRepositoryMetadata(MyCoRepository::class.java) + val method = ReflectionUtils.findRequiredMethod(MyCoRepository::class.java, "findOne", String::class.java, Continuation::class.java); + + assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(User::class.java) + } + + @Test // DATACMNS-1689 + fun shouldDetermineCorrectOptionalResultType() { + + val metadata = DefaultRepositoryMetadata(MyCoRepository::class.java) + val method = ReflectionUtils.findRequiredMethod(MyCoRepository::class.java, "findOneOptional", String::class.java, Continuation::class.java); + + assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(User::class.java) + } + + @Test // DATACMNS-1689 + fun shouldDetermineCorrectFlowResultType() { + + val metadata = DefaultRepositoryMetadata(MyCoRepository::class.java) + val method = ReflectionUtils.findRequiredMethod(MyCoRepository::class.java, "findMultiple", String::class.java); + + assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(User::class.java) + } + + @Test // DATACMNS-1689 + fun shouldDetermineCorrectSuspendedFlowResultType() { + + val metadata = DefaultRepositoryMetadata(MyCoRepository::class.java) + val method = ReflectionUtils.findRequiredMethod(MyCoRepository::class.java, "findMultipleSuspended", String::class.java, Continuation::class.java); + + assertThat(metadata.getReturnedDomainClass(method)).isEqualTo(User::class.java) + } + + interface MyCoRepository : CoroutineCrudRepository { + + suspend fun findOne(id: String): User + + suspend fun findOneOptional(id: String): User? + + fun findMultiple(id: String): Flow + + suspend fun findMultipleSuspended(id: String): Flow + } +}