diff --git a/src/main/java/org/springframework/data/repository/support/DefaultRepositoryInvokerFactory.java b/src/main/java/org/springframework/data/repository/support/DefaultRepositoryInvokerFactory.java index 24b0c3efd..6ef781510 100644 --- a/src/main/java/org/springframework/data/repository/support/DefaultRepositoryInvokerFactory.java +++ b/src/main/java/org/springframework/data/repository/support/DefaultRepositoryInvokerFactory.java @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.core.convert.ConversionService; import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.PagingAndSortingRepository; +import org.springframework.data.repository.reactive.ReactiveCrudRepository; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.format.support.DefaultFormattingConversionService; import org.springframework.util.Assert; @@ -95,7 +96,10 @@ public class DefaultRepositoryInvokerFactory implements RepositoryInvokerFactory @SuppressWarnings("unchecked") protected RepositoryInvoker createInvoker(RepositoryInformation information, Object repository) { - if (repository instanceof PagingAndSortingRepository && repository instanceof CrudRepository) { + if (repository instanceof ReactiveCrudRepository) { + return new ReactiveCrudRepositoryInvoker((ReactiveCrudRepository) repository, information, + conversionService); + } else if (repository instanceof PagingAndSortingRepository && repository instanceof CrudRepository) { return new PagingAndSortingRepositoryInvoker((PagingAndSortingRepository) repository, information, conversionService); } else if (repository instanceof CrudRepository) { diff --git a/src/main/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvoker.java b/src/main/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvoker.java new file mode 100644 index 000000000..0603e0da5 --- /dev/null +++ b/src/main/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvoker.java @@ -0,0 +1,104 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.support; + +import java.lang.reflect.Method; +import java.util.Optional; + +import org.springframework.core.convert.ConversionService; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.data.repository.core.CrudMethods; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.reactive.ReactiveCrudRepository; + +/** + * {@link RepositoryInvoker} to shortcut execution of CRUD methods into direct calls on a + * {@link ReactiveCrudRepository}. This invoker blocks on reactive types to provide synchronous + * execution semantics required by the repository populator infrastructure. + * + * @author Gichan Im + * @since 4.1 + * @see ReactiveCrudRepository + */ +class ReactiveCrudRepositoryInvoker extends ReflectionRepositoryInvoker { + + private final ReactiveCrudRepository repository; + + private final boolean customSaveMethod; + private final boolean customFindOneMethod; + private final boolean customFindAllMethod; + private final boolean customDeleteMethod; + + /** + * Creates a new {@link ReactiveCrudRepositoryInvoker} for the given {@link ReactiveCrudRepository}, + * {@link RepositoryMetadata} and {@link ConversionService}. + * + * @param repository must not be {@literal null}. + * @param metadata must not be {@literal null}. + * @param conversionService must not be {@literal null}. + */ + public ReactiveCrudRepositoryInvoker(ReactiveCrudRepository repository, RepositoryMetadata metadata, + ConversionService conversionService) { + + super(repository, metadata, conversionService); + + CrudMethods crudMethods = metadata.getCrudMethods(); + + this.customSaveMethod = isRedeclaredMethod(crudMethods.getSaveMethod()); + this.customFindOneMethod = isRedeclaredMethod(crudMethods.getFindOneMethod()); + this.customDeleteMethod = isRedeclaredMethod(crudMethods.getDeleteMethod()); + this.customFindAllMethod = isRedeclaredMethod(crudMethods.getFindAllMethod()); + this.repository = repository; + } + + @Override + public Iterable invokeFindAll(Sort sort) { + return customFindAllMethod ? super.invokeFindAll(sort) : repository.findAll().collectList().block(); + } + + @Override + public Iterable invokeFindAll(Pageable pageable) { + return customFindAllMethod ? super.invokeFindAll(pageable) : repository.findAll().collectList().block(); + } + + @Override + @SuppressWarnings("unchecked") + public Optional invokeFindById(Object id) { + return customFindOneMethod ? super.invokeFindById(id) + : (Optional) repository.findById(convertId(id)).blockOptional(); + } + + @Override + @SuppressWarnings("unchecked") + public T invokeSave(T entity) { + return customSaveMethod ? super.invokeSave(entity) : (T) repository.save(entity).block(); + } + + @Override + public void invokeDeleteById(Object id) { + + if (customDeleteMethod) { + super.invokeDeleteById(id); + } else { + repository.deleteById(convertId(id)).block(); + } + } + + private static boolean isRedeclaredMethod(Optional method) { + return method.map(it -> !it.getDeclaringClass().equals(ReactiveCrudRepository.class)).orElse(false); + } +} diff --git a/src/test/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvokerUnitTests.java b/src/test/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvokerUnitTests.java new file mode 100644 index 000000000..20d948955 --- /dev/null +++ b/src/test/java/org/springframework/data/repository/support/ReactiveCrudRepositoryInvokerUnitTests.java @@ -0,0 +1,138 @@ +/* + * Copyright 2026-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.repository.support; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; +import static org.springframework.data.repository.support.RepositoryInvocationTestUtils.*; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import org.springframework.core.convert.support.GenericConversionService; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.repository.reactive.ReactiveCrudRepository; +import org.springframework.format.support.DefaultFormattingConversionService; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Unit tests for {@link ReactiveCrudRepositoryInvoker}. + * + * @author Gichan Im + */ +@ExtendWith(MockitoExtension.class) +class ReactiveCrudRepositoryInvokerUnitTests { + + @Mock ReactivePersonRepository repository; + @Mock ReactiveOrderRepository orderRepository; + + @Test // GH-2090 + void invokesRedeclaredSave() { + + when(orderRepository.save(any())).thenReturn(Mono.just(new Order())); + + getInvokerFor(orderRepository, expectInvocationOnType(ReactiveOrderRepository.class)).invokeSave(new Order()); + } + + @Test // GH-2090 + void invokesRedeclaredFindById() { + + when(orderRepository.findById(any(Long.class))).thenReturn(Mono.empty()); + + getInvokerFor(orderRepository, expectInvocationOnType(ReactiveOrderRepository.class)).invokeFindById(1L); + } + + @Test // GH-2090 + void invokesRedeclaredDelete() { + + when(orderRepository.deleteById(any(Long.class))).thenReturn(Mono.empty()); + + getInvokerFor(orderRepository, expectInvocationOnType(ReactiveOrderRepository.class)).invokeDeleteById(1L); + } + + @Test // GH-2090 + void invokesSaveOnReactiveCrudRepository() throws Exception { + + when(repository.save(any())).thenReturn(Mono.just(new Person())); + + var method = ReactiveCrudRepository.class.getMethod("save", Object.class); + getInvokerFor(repository, expectInvocationOf(method)).invokeSave(new Person()); + } + + @Test // GH-2090 + void invokesFindByIdOnReactiveCrudRepository() throws Exception { + + when(repository.findById(any(Long.class))).thenReturn(Mono.empty()); + + var method = ReactiveCrudRepository.class.getMethod("findById", Object.class); + getInvokerFor(repository, expectInvocationOf(method)).invokeFindById(1L); + } + + @Test // GH-2090 + void invokesDeleteByIdOnReactiveCrudRepository() throws Exception { + + when(repository.deleteById(any(Long.class))).thenReturn(Mono.empty()); + + var method = ReactiveCrudRepository.class.getMethod("deleteById", Object.class); + getInvokerFor(repository, expectInvocationOf(method)).invokeDeleteById(1L); + } + + @Test // GH-2090 + void invokesFindAllOnReactiveCrudRepository() throws Exception { + + when(repository.findAll()).thenReturn(Flux.empty()); + + var method = ReactiveCrudRepository.class.getMethod("findAll"); + getInvokerFor(repository, expectInvocationOf(method)).invokeFindAll(Pageable.unpaged()); + getInvokerFor(repository, expectInvocationOf(method)).invokeFindAll(Sort.unsorted()); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static RepositoryInvoker getInvokerFor(Object repository, VerifyingMethodInterceptor interceptor) { + + var proxy = getVerifyingRepositoryProxy(repository, interceptor); + + RepositoryMetadata metadata = new DefaultRepositoryMetadata(repository.getClass().getInterfaces()[0]); + GenericConversionService conversionService = new DefaultFormattingConversionService(); + + return new ReactiveCrudRepositoryInvoker((ReactiveCrudRepository) proxy, metadata, conversionService); + } + + static class Person {} + + static class Order {} + + interface ReactivePersonRepository extends ReactiveCrudRepository {} + + interface ReactiveOrderRepository extends ReactiveCrudRepository { + + @Override + Mono save(S entity); + + @Override + Mono findById(Long id); + + @Override + Mono deleteById(Long id); + } +}