From e1bbdf09139dca7c21ec64e140bcc3bda463b2f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Basl=C3=A9?= Date: Thu, 14 Dec 2023 11:43:37 +0100 Subject: [PATCH] Add support for bean overriding in tests This commit introduces two sets of annotations (`@TestBean` on one side and `MockitoBean`/`MockitoSpyBean` on the other side), as well as an extension mecanism based on the `@BeanOverride` meta-annotation. Extension implementors are expected to only provide an annotation, a BeanOverrideProcessor implementation and an OverrideMetadata subclass. Closes gh-29917. --- framework-docs/modules/ROOT/nav.adoc | 1 + .../annotations/integration-spring.adoc | 2 + .../annotation-beanoverriding.adoc | 130 ++++++ spring-test/spring-test.gradle | 2 + .../test/bean/override/BeanOverride.java | 46 +++ .../BeanOverrideBeanPostProcessor.java | 370 ++++++++++++++++++ .../BeanOverrideContextCustomizerFactory.java | 100 +++++ .../bean/override/BeanOverrideParser.java | 141 +++++++ .../bean/override/BeanOverrideProcessor.java | 70 ++++ .../bean/override/BeanOverrideStrategy.java | 45 +++ .../BeanOverrideTestExecutionListener.java | 107 +++++ .../test/bean/override/OverrideMetadata.java | 153 ++++++++ .../bean/override/convention/TestBean.java | 78 ++++ .../convention/TestBeanOverrideProcessor.java | 145 +++++++ .../override/convention/package-info.java | 11 + .../bean/override/mockito/Definition.java | 118 ++++++ .../bean/override/mockito/MockDefinition.java | 170 ++++++++ .../test/bean/override/mockito/MockReset.java | 139 +++++++ .../bean/override/mockito/MockitoBean.java | 86 ++++ .../mockito/MockitoBeanOverrideProcessor.java | 38 ++ .../bean/override/mockito/MockitoBeans.java | 41 ++ .../MockitoResetTestExecutionListener.java | 126 ++++++ .../bean/override/mockito/MockitoSpyBean.java | 84 ++++ .../mockito/MockitoTestExecutionListener.java | 139 +++++++ .../bean/override/mockito/SpyDefinition.java | 145 +++++++ .../bean/override/mockito/package-info.java | 9 + .../test/bean/override/package-info.java | 9 + .../main/resources/META-INF/spring.factories | 4 + .../BeanOverrideBeanPostProcessorTests.java | 328 ++++++++++++++++ .../override/BeanOverrideParserTests.java | 122 ++++++ .../bean/override/OverrideMetadataTests.java | 68 ++++ .../TestBeanOverrideProcessorTests.java | 130 ++++++ .../ExampleBeanOverrideAnnotation.java | 38 ++ .../example/ExampleBeanOverrideProcessor.java | 49 +++ .../bean/override/example/ExampleService.java | 28 ++ .../example/FailingExampleService.java | 34 ++ .../override/example/RealExampleService.java | 37 ++ .../TestBeanOverrideMetaAnnotation.java | 27 ++ .../example/TestOverrideMetadata.java | 119 ++++++ .../bean/override/example/package-info.java | 9 + .../context/TestExecutionListenersTests.java | 21 +- 41 files changed, 3516 insertions(+), 3 deletions(-) create mode 100644 framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-beanoverriding.adoc create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverride.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessor.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideContextCustomizerFactory.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideParser.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideProcessor.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideStrategy.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideTestExecutionListener.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/OverrideMetadata.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBean.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessor.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/convention/package-info.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/Definition.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockDefinition.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockReset.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBean.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeanOverrideProcessor.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeans.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoResetTestExecutionListener.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoSpyBean.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoTestExecutionListener.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/SpyDefinition.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/mockito/package-info.java create mode 100644 spring-test/src/main/java/org/springframework/test/bean/override/package-info.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessorTests.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideParserTests.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/OverrideMetadataTests.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessorTests.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideAnnotation.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideProcessor.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleService.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/FailingExampleService.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/RealExampleService.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/TestBeanOverrideMetaAnnotation.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/TestOverrideMetadata.java create mode 100644 spring-test/src/test/java/org/springframework/test/bean/override/example/package-info.java diff --git a/framework-docs/modules/ROOT/nav.adoc b/framework-docs/modules/ROOT/nav.adoc index 611de4bd83e..9a4b979caec 100644 --- a/framework-docs/modules/ROOT/nav.adoc +++ b/framework-docs/modules/ROOT/nav.adoc @@ -183,6 +183,7 @@ ***** xref:testing/annotations/integration-spring/annotation-sqlmergemode.adoc[] ***** xref:testing/annotations/integration-spring/annotation-sqlgroup.adoc[] ***** xref:testing/annotations/integration-spring/annotation-disabledinaotmode.adoc[] +***** xref:testing/annotations/integration-spring/annotation-beanoverriding.adoc[] **** xref:testing/annotations/integration-junit4.adoc[] **** xref:testing/annotations/integration-junit-jupiter.adoc[] **** xref:testing/annotations/integration-meta.adoc[] diff --git a/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring.adoc b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring.adoc index 3804efbc56f..997f717ee73 100644 --- a/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring.adoc +++ b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring.adoc @@ -28,4 +28,6 @@ Spring's testing annotations include the following: * xref:testing/annotations/integration-spring/annotation-sqlmergemode.adoc[`@SqlMergeMode`] * xref:testing/annotations/integration-spring/annotation-sqlgroup.adoc[`@SqlGroup`] * xref:testing/annotations/integration-spring/annotation-disabledinaotmode.adoc[`@DisabledInAotMode`] +* xref:testing/annotations/integration-spring/annotation-beanoverriding.adoc#spring-testing-annotation-beanoverriding-testbean[`@TestBean`] +* xref:testing/annotations/integration-spring/annotation-beanoverriding.adoc#spring-testing-annotation-beanoverriding-mockitobean[`@MockitoBean` and `@MockitoSpyBean`] diff --git a/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-beanoverriding.adoc b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-beanoverriding.adoc new file mode 100644 index 00000000000..9abf9daf349 --- /dev/null +++ b/framework-docs/modules/ROOT/pages/testing/annotations/integration-spring/annotation-beanoverriding.adoc @@ -0,0 +1,130 @@ +[[spring-testing-annotation-beanoverriding]] += Bean Overriding in Tests + +Bean Overriding in Tests refers to the ability to override specific beans in the Context +for a test class, by annotating one or more fields in said test class. + +NOTE: This is intended as a less risky alternative to the practice of registering a bean via +`@Bean` with the `DefaultListableBeanFactory` `setAllowBeanDefinitionOverriding` set to +`true`. + +The Spring Testing Framework provides two sets of annotations presented below. One relies +purely on Spring, while the second set relies on the Mockito third party library. + +[[spring-testing-annotation-beanoverriding-testbean]] +== `@TestBean` + +`@TestBean` is used on a test class field to override a specific bean with an instance +provided by a conventionally named static method. + +By default, the bean name and the associated static method name are derived from the +annotated field's name, but the annotation allows for specific values to be provided. + +The `@TestBean` annotation uses the `REPLACE_DEFINITION` +xref:#spring-testing-annotation-beanoverriding-extending[strategy for test bean overriding]. + +The following example shows how to fully configure the `@TestBean` annotation, with +explicit values equivalent to the default: + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +---- + class OverrideBeanTests { + @TestBean(name = "service", methodName = "serviceTestOverride") // <1> + private CustomService service; + + // test case body... + + private static CustomService serviceTestOverride() { // <2> + return new MyFakeCustomService(); + } + } +---- +<1> Mark a field for bean overriding in this test class +<2> The result of this static method will be used as the instance and injected into the field +====== + + +[[spring-testing-annotation-beanoverriding-mockitobean]] +== `@MockitoBean` and `@MockitoSpyBean` + +`@MockitoBean` and `@MockitoSpyBean` are used on a test class field to override a bean +with a mocking and spying instance, respectively. In the later case, the original bean +definition is not replaced but instead an early instance is captured and wrapped by the +spy. + +By default, the name of the bean to override is derived from the annotated field's name, +but both annotations allows for a specific `name` to be provided. Each annotation also +defines Mockito-specific attributes to fine-tune the mocking details. + +The `@MockitoBean` annotation uses the `CREATE_OR_REPLACE_DEFINITION` +xref:#spring-testing-annotation-beanoverriding-extending[strategy for test bean overriding]. + +The `@MockitoSpyBean` annotation uses the `WRAP_EARLY_BEAN` +xref:#spring-testing-annotation-beanoverriding-extending[strategy] and the original instance +is wrapped in a Mockito spy. + +The following example shows how to configure the bean name for both `@MockitoBean` and +`@MockitoSpyBean` annotations: + +[tabs] +====== +Java:: ++ +[source,java,indent=0,subs="verbatim,quotes",role="primary"] +---- + class OverrideBeanTests { + @MockitoBean(name = "service1") // <1> + private CustomService mockService; + + @MockitoSpyBean(name = "service2") // <2> + private CustomService spyService; // <3> + + // test case body... + } +---- +<1> Mark `mockService` as a Mockito mock override of bean `service1` in this test class +<2> Mark `spyService` as a Mockito spy override of bean `service2` in this test class +<3> Both fields will be injected with the Mockito values (the mock and the spy respectively) +====== + + +[[spring-testing-annotation-beanoverriding-extending]] +== Extending bean override with a custom annotation + +The three annotations introduced above build upon the `@BeanOverride` meta-annotation +and associated infrastructure, which allows to define custom bean overriding variants. + +In order to provide an extension, three classes are needed: + - a concrete `BeanOverrideProcessor` `

` + - a concrete `OverrideMetadata` created by said processor + - an annotation meta-annotated with `@BeanOverride(P.class)` + +The Spring TestContext Framework includes infrastructure classes that support bean +overriding: a `BeanPostProcessor`, a `TestExecutionListener` and a `ContextCustomizerFactory`. +These are automatically registered via the Spring TestContext Framework `spring.factories` +file. + +The test classes are parsed looking for any field meta-annotated with `@BeanOverride`, +instantiating the relevant `BeanOverrideProcessor` in order to register an `OverrideMetadata`. + +Then the `BeanOverrideBeanPostProcessor` will use that information to alter the Context, +registering and replacing bean definitions as influenced by each metadata +`BeanOverrideStrategy`: + + - `REPLACE_DEFINITION`: the bean post-processor replaces the bean definition. +If it is not present in the context, an exception is thrown. + - `CREATE_OR_REPLACE_DEFINITION`: same as above but if the bean definition is not present +in the context, one is created + - `WRAP_EARLY_BEAN`: an original instance is obtained via +`SmartInstantiationAwareBeanPostProcessor#getEarlyBeanReference(Object, String)` and +provided to the processor during `OverrideMetadata` creation. + +NOTE: The Bean Overriding infrastructure works best with singleton beans. It also doesn't +include any bean resolution (unlike e.g. an `@Autowired`-annotated field). As such, the +name of the bean to override MUST be somehow provided to or computed by the +`BeanOverrideProcessor`. Typically, the end user provides the name as part of the custom +annotation's attributes, or the annotated field's name. \ No newline at end of file diff --git a/spring-test/spring-test.gradle b/spring-test/spring-test.gradle index ed316f9f77e..4dac2ee89ec 100644 --- a/spring-test/spring-test.gradle +++ b/spring-test/spring-test.gradle @@ -42,6 +42,7 @@ dependencies { optional("org.jetbrains.kotlinx:kotlinx-coroutines-reactor") optional("org.junit.jupiter:junit-jupiter-api") optional("org.junit.platform:junit-platform-launcher") // for AOT processing + optional("org.mockito:mockito-core") optional("org.seleniumhq.selenium:htmlunit-driver") { exclude group: "commons-logging", module: "commons-logging" exclude group: "net.bytebuddy", module: "byte-buddy" @@ -79,6 +80,7 @@ dependencies { testImplementation("org.hibernate:hibernate-validator") testImplementation("org.hsqldb:hsqldb") testImplementation("org.junit.platform:junit-platform-testkit") + testImplementation("org.mockito:mockito-core") testRuntimeOnly("com.sun.xml.bind:jaxb-core") testRuntimeOnly("com.sun.xml.bind:jaxb-impl") testRuntimeOnly("org.glassfish:jakarta.el") diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverride.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverride.java new file mode 100644 index 00000000000..114f8576950 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverride.java @@ -0,0 +1,46 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Mark an annotation as eligible for Bean Override parsing. + * This meta-annotation provides a {@link BeanOverrideProcessor} class which + * must be capable of handling the annotated annotation. + * + *

Target annotation must have a {@link RetentionPolicy} of {@code RUNTIME} + * and be applicable to {@link java.lang.reflect.Field Fields} only. + * @see BeanOverrideBeanPostProcessor + * + * @author Simon Baslé + * @since 6.2 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.ANNOTATION_TYPE}) +public @interface BeanOverride { + + /** + * A {@link BeanOverrideProcessor} implementation class by which the target + * annotation should be processed. Implementations must have a no-argument + * constructor. + */ + Class value(); +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessor.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessor.java new file mode 100644 index 00000000000..e6561e1bba3 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessor.java @@ -0,0 +1,370 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; + +import org.springframework.aop.scope.ScopedProxyUtils; +import org.springframework.beans.BeansException; +import org.springframework.beans.PropertyValues; +import org.springframework.beans.factory.BeanCreationException; +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.beans.factory.BeanFactoryUtils; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.ConstructorArgumentValues; +import org.springframework.beans.factory.config.InstantiationAwareBeanPostProcessor; +import org.springframework.beans.factory.config.RuntimeBeanReference; +import org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.ApplicationContext; +import org.springframework.core.Ordered; +import org.springframework.core.PriorityOrdered; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.StringUtils; + +/** + * A {@link BeanFactoryPostProcessor} used to register and inject overriding + * bean metadata with the {@link ApplicationContext}. A set of + * {@link OverrideMetadata} must be passed to the processor. + * A {@link BeanOverrideParser} can typically be used to parse these from test + * classes that use any annotation meta-annotated with {@link BeanOverride} to + * mark override sites. + * + *

This processor supports two {@link BeanOverrideStrategy}: + *

+ * + *

This processor also provides support for injecting the overridden bean + * instances into their corresponding annotated {@link Field fields}. + * + * @author Simon Baslé + * @since 6.2 + */ +public class BeanOverrideBeanPostProcessor implements InstantiationAwareBeanPostProcessor, + BeanFactoryAware, BeanFactoryPostProcessor, Ordered { + + private static final String INFRASTRUCTURE_BEAN_NAME = BeanOverrideBeanPostProcessor.class.getName(); + private static final String EARLY_INFRASTRUCTURE_BEAN_NAME = BeanOverrideBeanPostProcessor.WrapEarlyBeanPostProcessor.class.getName(); + + private final Set overrideMetadata; + private final Map earlyOverrideMetadata = new HashMap<>(); + + private ConfigurableListableBeanFactory beanFactory; + + private final Map beanNameRegistry = new HashMap<>(); + + private final Map fieldRegistry = new HashMap<>(); + + /** + * Create a new {@link BeanOverrideBeanPostProcessor} instance with the + * given {@link OverrideMetadata} set. + * @param overrideMetadata the initial override metadata + */ + public BeanOverrideBeanPostProcessor(Set overrideMetadata) { + this.overrideMetadata = overrideMetadata; + } + + + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE - 10; + } + + @Override + public void setBeanFactory(BeanFactory beanFactory) throws BeansException { + Assert.isInstanceOf(ConfigurableListableBeanFactory.class, beanFactory, + "Beans overriding can only be used with a ConfigurableListableBeanFactory"); + this.beanFactory = (ConfigurableListableBeanFactory) beanFactory; + } + + /** + * Return this processor's {@link OverrideMetadata} set. + */ + protected Set getOverrideMetadata() { + return this.overrideMetadata; + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + Assert.state(this.beanFactory == beanFactory, "Unexpected beanFactory to postProcess"); + Assert.isInstanceOf(BeanDefinitionRegistry.class, beanFactory, + "Bean overriding annotations can only be used on bean factories that implement " + + "BeanDefinitionRegistry"); + postProcessWithRegistry((BeanDefinitionRegistry) beanFactory); + } + + private void postProcessWithRegistry(BeanDefinitionRegistry registry) { + //Note that a tracker bean is registered down the line only if there is some overrideMetadata parsed + Set overrideMetadata = getOverrideMetadata(); + for (OverrideMetadata metadata : overrideMetadata) { + registerBeanOverride(registry, metadata); + } + } + + /** + * Copy the details of a {@link BeanDefinition} to the definition created by + * this processor for a given {@link OverrideMetadata}. Defaults to copying + * the {@link BeanDefinition#isPrimary()} attribute and scope. + */ + protected void copyBeanDefinitionDetails(BeanDefinition from, RootBeanDefinition to) { + to.setPrimary(from.isPrimary()); + to.setScope(from.getScope()); + } + + private void registerBeanOverride(BeanDefinitionRegistry registry, OverrideMetadata overrideMetadata) { + switch (overrideMetadata.getBeanOverrideStrategy()) { + case REPLACE_DEFINITION -> registerReplaceDefinition(registry, overrideMetadata, true); + case REPLACE_OR_CREATE_DEFINITION -> registerReplaceDefinition(registry, overrideMetadata, false); + case WRAP_EARLY_BEAN -> registerWrapEarly(overrideMetadata); + } + } + + private void registerReplaceDefinition(BeanDefinitionRegistry registry, OverrideMetadata overrideMetadata, + boolean enforceExistingDefinition) { + RootBeanDefinition beanDefinition = createBeanDefinition(overrideMetadata); + String beanName = overrideMetadata.getExpectedBeanName(); + + BeanDefinition existingBeanDefinition = null; + if (registry.containsBeanDefinition(beanName)) { + existingBeanDefinition = registry.getBeanDefinition(beanName); + copyBeanDefinitionDetails(existingBeanDefinition, beanDefinition); + registry.removeBeanDefinition(beanName); + } + else if (enforceExistingDefinition) { + throw new IllegalStateException("Unable to override " + overrideMetadata.getBeanOverrideDescription() + + " bean, expected a bean definition to replace with name '" + beanName + "'"); + } + registry.registerBeanDefinition(beanName, beanDefinition); + + Object override = overrideMetadata.createOverride(beanName, existingBeanDefinition, null); + if (this.beanFactory.isSingleton(beanName)) { + // Now we have an instance (the override) that we can register. + // At this stage we don't expect a singleton instance to be present, + // and this call will throw if there is such an instance already. + this.beanFactory.registerSingleton(beanName, override); + } + + overrideMetadata.track(override, this.beanFactory); + this.beanNameRegistry.put(overrideMetadata, beanName); + this.fieldRegistry.put(overrideMetadata.field(), beanName); + } + + /** + * Check that the expected bean name is registered and matches the type to override. + * If so, put the override metadata in the early tracking map. + * The map will later be checked to see if a given bean should be wrapped + * upon creation, during the {@link WrapEarlyBeanPostProcessor#getEarlyBeanReference(Object, String)} + * phase + */ + private void registerWrapEarly(OverrideMetadata metadata) { + Set existingBeanNames = getExistingBeanNames(metadata.typeToOverride()); + String beanName = metadata.getExpectedBeanName(); + if (!existingBeanNames.contains(beanName)) { + throw new IllegalStateException("Unable to override wrap-early bean named '" + beanName + "', not found among " + + existingBeanNames); + } + this.earlyOverrideMetadata.put(beanName, metadata); + this.beanNameRegistry.put(metadata, beanName); + this.fieldRegistry.put(metadata.field(), beanName); + } + + /** + * Check early overrides records and use the {@link OverrideMetadata} to + * create an override instance from the provided bean, if relevant. + *

Called during the {@link SmartInstantiationAwareBeanPostProcessor} + * phases (see {@link WrapEarlyBeanPostProcessor#getEarlyBeanReference(Object, String)} + * and {@link WrapEarlyBeanPostProcessor#postProcessAfterInitialization(Object, String)}). + */ + protected final Object wrapIfNecessary(Object bean, String beanName) throws BeansException { + final OverrideMetadata metadata = this.earlyOverrideMetadata.get(beanName); + if (metadata != null && metadata.getBeanOverrideStrategy() == BeanOverrideStrategy.WRAP_EARLY_BEAN) { + bean = metadata.createOverride(beanName, null, bean); + metadata.track(bean, this.beanFactory); + } + return bean; + } + + private RootBeanDefinition createBeanDefinition(OverrideMetadata metadata) { + RootBeanDefinition definition = new RootBeanDefinition(metadata.typeToOverride().resolve()); + definition.setTargetType(metadata.typeToOverride()); + return definition; + } + + private Set getExistingBeanNames(ResolvableType resolvableType) { + Set beans = new LinkedHashSet<>( + Arrays.asList(this.beanFactory.getBeanNamesForType(resolvableType, true, false))); + Class type = resolvableType.resolve(Object.class); + for (String beanName : this.beanFactory.getBeanNamesForType(FactoryBean.class, true, false)) { + beanName = BeanFactoryUtils.transformedBeanName(beanName); + BeanDefinition beanDefinition = this.beanFactory.getBeanDefinition(beanName); + Object attribute = beanDefinition.getAttribute(FactoryBean.OBJECT_TYPE_ATTRIBUTE); + if (resolvableType.equals(attribute) || type.equals(attribute)) { + beans.add(beanName); + } + } + beans.removeIf(this::isScopedTarget); + return beans; + } + + private boolean isScopedTarget(String beanName) { + try { + return ScopedProxyUtils.isScopedTarget(beanName); + } + catch (Throwable ex) { + return false; + } + } + + private void postProcessField(Object bean, Field field) { + String beanName = this.fieldRegistry.get(field); + if (StringUtils.hasText(beanName)) { + inject(field, bean, beanName); + } + } + + @Override + public PropertyValues postProcessProperties(PropertyValues pvs, Object bean, String beanName) + throws BeansException { + ReflectionUtils.doWithFields(bean.getClass(), field -> postProcessField(bean, field)); + return pvs; + } + + void inject(Field field, Object target, OverrideMetadata overrideMetadata) { + String beanName = this.beanNameRegistry.get(overrideMetadata); + Assert.state(StringUtils.hasLength(beanName), () -> "No bean found for overrideMetadata " + overrideMetadata); + inject(field, target, beanName); + } + + private void inject(Field field, Object target, String beanName) { + try { + field.setAccessible(true); + Object existingValue = ReflectionUtils.getField(field, target); + Object bean = this.beanFactory.getBean(beanName, field.getType()); + if (existingValue == bean) { + return; + } + Assert.state(existingValue == null, () -> "The existing value '" + existingValue + + "' of field '" + field + "' is not the same as the new value '" + bean + "'"); + ReflectionUtils.setField(field, target, bean); + } + catch (Throwable ex) { + throw new BeanCreationException("Could not inject field '" + field + "'", ex); + } + } + + /** + * Register the processor with a {@link BeanDefinitionRegistry}. + * Not required when using the Spring TestContext Framework, as registration + * is automatic via the {@link org.springframework.core.io.support.SpringFactoriesLoader SpringFactoriesLoader} + * mechanism. + * @param registry the bean definition registry + * @param overrideMetadata the initial override metadata set + */ + public static void register(BeanDefinitionRegistry registry, @Nullable Set overrideMetadata) { + //early processor + getOrAddInfrastructureBeanDefinition(registry, WrapEarlyBeanPostProcessor.class, EARLY_INFRASTRUCTURE_BEAN_NAME, + constructorArguments -> constructorArguments.addIndexedArgumentValue(0, + new RuntimeBeanReference(INFRASTRUCTURE_BEAN_NAME))); + + //main processor + BeanDefinition definition = getOrAddInfrastructureBeanDefinition(registry, BeanOverrideBeanPostProcessor.class, + INFRASTRUCTURE_BEAN_NAME, constructorArguments -> constructorArguments + .addIndexedArgumentValue(0, new LinkedHashSet())); + ConstructorArgumentValues.ValueHolder constructorArg = definition.getConstructorArgumentValues() + .getIndexedArgumentValue(0, Set.class); + @SuppressWarnings("unchecked") + Set existing = (Set) constructorArg.getValue(); + if (overrideMetadata != null && existing != null) { + existing.addAll(overrideMetadata); + } + } + + private static BeanDefinition getOrAddInfrastructureBeanDefinition(BeanDefinitionRegistry registry, + Class clazz, String beanName, Consumer constructorArgumentsConsumer) { + if (!registry.containsBeanDefinition(beanName)) { + RootBeanDefinition definition = new RootBeanDefinition(clazz); + definition.setRole(BeanDefinition.ROLE_INFRASTRUCTURE); + ConstructorArgumentValues constructorArguments = definition.getConstructorArgumentValues(); + constructorArgumentsConsumer.accept(constructorArguments); + registry.registerBeanDefinition(beanName, definition); + return definition; + } + return registry.getBeanDefinition(beanName); + } + + private static final class WrapEarlyBeanPostProcessor implements SmartInstantiationAwareBeanPostProcessor, + PriorityOrdered { + + private final BeanOverrideBeanPostProcessor mainProcessor; + private final Map earlyReferences; + + private WrapEarlyBeanPostProcessor(BeanOverrideBeanPostProcessor mainProcessor) { + this.mainProcessor = mainProcessor; + this.earlyReferences = new ConcurrentHashMap<>(16); + } + + @Override + public int getOrder() { + return Ordered.HIGHEST_PRECEDENCE; + } + + @Override + public Object getEarlyBeanReference(Object bean, String beanName) throws BeansException { + if (bean instanceof FactoryBean) { + return bean; + } + this.earlyReferences.put(getCacheKey(bean, beanName), bean); + return this.mainProcessor.wrapIfNecessary(bean, beanName); + } + + @Override + public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { + if (bean instanceof FactoryBean) { + return bean; + } + if (this.earlyReferences.remove(getCacheKey(bean, beanName)) != bean) { + return this.mainProcessor.wrapIfNecessary(bean, beanName); + } + return bean; + } + + private String getCacheKey(Object bean, String beanName) { + return StringUtils.hasLength(beanName) ? beanName : bean.getClass().getName(); + } + + } +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideContextCustomizerFactory.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideContextCustomizerFactory.java new file mode 100644 index 00000000000..cf394301d75 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideContextCustomizerFactory.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.util.List; +import java.util.Set; + +import org.springframework.aot.hint.annotation.Reflective; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.test.context.ContextConfigurationAttributes; +import org.springframework.test.context.ContextCustomizer; +import org.springframework.test.context.ContextCustomizerFactory; +import org.springframework.test.context.MergedContextConfiguration; +import org.springframework.test.context.TestContextAnnotationUtils; + +/** + * A {@link ContextCustomizerFactory} to add support for Bean Overriding. + * + * @author Simon Baslé + * @since 6.2 + */ +public class BeanOverrideContextCustomizerFactory implements ContextCustomizerFactory { + + @Override + public ContextCustomizer createContextCustomizer(Class testClass, + List configAttributes) { + BeanOverrideParser parser = new BeanOverrideParser(); + parseMetadata(testClass, parser); + if (parser.getOverrideMetadata().isEmpty()) { + return null; + } + + return new BeanOverrideContextCustomizer(parser.getOverrideMetadata()); + } + + private void parseMetadata(Class testClass, BeanOverrideParser parser) { + parser.parse(testClass); + if (TestContextAnnotationUtils.searchEnclosingClass(testClass)) { + parseMetadata(testClass.getEnclosingClass(), parser); + } + } + + /** + * A {@link ContextCustomizer} for Bean Overriding in tests. + */ + @Reflective + static final class BeanOverrideContextCustomizer implements ContextCustomizer { + + private final Set metadata; + + /** + * Construct a context customizer given some pre-existing override + * metadata. + * @param metadata a set of concrete {@link OverrideMetadata} provided + * by the underlying {@link BeanOverrideParser} + */ + BeanOverrideContextCustomizer(Set metadata) { + this.metadata = metadata; + } + + @Override + public void customizeContext(ConfigurableApplicationContext context, MergedContextConfiguration mergedConfig) { + if (context instanceof BeanDefinitionRegistry registry) { + BeanOverrideBeanPostProcessor.register(registry, this.metadata); + } + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + BeanOverrideContextCustomizer other = (BeanOverrideContextCustomizer) obj; + return this.metadata.equals(other.metadata); + } + + @Override + public int hashCode() { + return this.metadata.hashCode(); + } + } +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideParser.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideParser.java new file mode 100644 index 00000000000..5a2d4acb3ac --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideParser.java @@ -0,0 +1,141 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.springframework.beans.factory.support.BeanDefinitionValidationException; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; + +/** + * A parser that discovers annotations meta-annotated with {@link BeanOverride} + * on fields of a given class and creates {@link OverrideMetadata} accordingly. + * + * @author Simon Baslé + */ +class BeanOverrideParser { + + private final Set parsedMetadata; + + BeanOverrideParser() { + this.parsedMetadata = new LinkedHashSet<>(); + } + + /** + * Getter for the set of {@link OverrideMetadata} once {@link #parse(Class)} + * has been called. + */ + Set getOverrideMetadata() { + return Collections.unmodifiableSet(this.parsedMetadata); + } + + /** + * Discover fields of the provided class that are meta-annotated with + * {@link BeanOverride}, then instantiate their corresponding + * {@link BeanOverrideProcessor} and use it to create an {@link OverrideMetadata} + * instance for each field. Each call to {@code parse} adds the parsed + * metadata to the parser's override metadata {{@link #getOverrideMetadata()} + * set} + * @param testClass the class which fields to inspect + */ + void parse(Class testClass) { + ReflectionUtils.doWithFields(testClass, field -> parseField(field, testClass)); + } + + /** + * Check if any field of the provided {@code testClass} is meta-annotated + * with {@link BeanOverride}. + *

This is similar to the initial discovery of fields in {@link #parse(Class)} + * without the heavier steps of instantiating processors and creating + * {@link OverrideMetadata}, so this method leaves the current state of + * {@link #getOverrideMetadata()} unchanged. + * @param testClass the class which fields to inspect + * @return true if there is a bean override annotation present, false otherwise + * @see #parse(Class) + */ + boolean hasBeanOverride(Class testClass) { + AtomicBoolean hasBeanOverride = new AtomicBoolean(); + ReflectionUtils.doWithFields(testClass, field -> { + if (hasBeanOverride.get()) { + return; + } + final long count = MergedAnnotations.from(field, MergedAnnotations.SearchStrategy.DIRECT) + .stream(BeanOverride.class) + .count(); + hasBeanOverride.compareAndSet(false, count > 0L); + }); + return hasBeanOverride.get(); + } + + private void parseField(Field field, Class source) { + AtomicBoolean overrideAnnotationFound = new AtomicBoolean(); + + MergedAnnotations.from(field, MergedAnnotations.SearchStrategy.DIRECT) + .stream(BeanOverride.class) + .map(bo -> { + var a = bo.getMetaSource(); + Assert.notNull(a, "BeanOverride annotation must be meta-present"); + return new AnnotationPair(a.synthesize(), bo); + }) + .forEach(pair -> { + var metaAnnotation = pair.metaAnnotation().synthesize(); + final BeanOverrideProcessor processor = getProcessorInstance(metaAnnotation.value()); + if (processor == null) { + return; + } + ResolvableType typeToOverride = processor.getOrDeduceType(field, pair.annotation(), source); + + Assert.state(overrideAnnotationFound.compareAndSet(false, true), + "Multiple bean override annotations found on annotated field <" + field + ">"); + OverrideMetadata metadata = processor.createMetadata(field, pair.annotation(), typeToOverride); + boolean isNewDefinition = this.parsedMetadata.add(metadata); + Assert.state(isNewDefinition, () -> "Duplicate " + metadata.getBeanOverrideDescription() + + " overrideMetadata " + metadata); + }); + } + + @Nullable + private BeanOverrideProcessor getProcessorInstance(Class processorClass) { + final Constructor constructor = ClassUtils.getConstructorIfAvailable(processorClass); + if (constructor != null) { + ReflectionUtils.makeAccessible(constructor); + try { + return constructor.newInstance(); + } + catch (InstantiationException | IllegalAccessException | InvocationTargetException ex) { + throw new BeanDefinitionValidationException("Could not get an instance of BeanOverrideProcessor", ex); + } + } + return null; + } + + private record AnnotationPair(Annotation annotation, MergedAnnotation metaAnnotation) {} + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideProcessor.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideProcessor.java new file mode 100644 index 00000000000..c4621737502 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideProcessor.java @@ -0,0 +1,70 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.TypeVariable; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.MergedAnnotation; + +/** + * An interface for Bean Overriding concrete processing. + * Processors are generally linked to one or more specific concrete annotations + * (meta-annotated with {@link BeanOverride}) and specify different steps in the + * process of parsing these annotations, ultimately creating + * {@link OverrideMetadata} which will be used to instantiate the overrides. + * + *

Implementations are required to have a no-argument constructor and be + * stateless. + * + * @author Simon Baslé + * @since 6.2 + */ +@FunctionalInterface +public interface BeanOverrideProcessor { + + /** + * Determine a {@link ResolvableType} for which an {@link OverrideMetadata} + * instance will be created, e.g. by using the annotation to determine the + * type. + *

Defaults to the field corresponding {@link ResolvableType}, + * additionally tracking the source class if the field is a {@link TypeVariable}. + */ + default ResolvableType getOrDeduceType(Field field, Annotation annotation, Class source) { + return (field.getGenericType() instanceof TypeVariable) ? ResolvableType.forField(field, source) + : ResolvableType.forField(field); + } + + /** + * Create an {@link OverrideMetadata} for a given annotated field and target + * {@link #getOrDeduceType(Field, Annotation, Class) type}. + * Specific implementations of metadata can have state to be used during + * override {@link OverrideMetadata#createOverride(String, BeanDefinition, + * Object) instance creation} (e.g. from further parsing the annotation or + * the annotated field). + * @param field the annotated field + * @param overrideAnnotation the field annotation + * @param typeToOverride the target type + * @return a new {@link OverrideMetadata} + * @see #getOrDeduceType(Field, Annotation, Class) + * @see MergedAnnotation#synthesize() + */ + OverrideMetadata createMetadata(Field field, Annotation overrideAnnotation, ResolvableType typeToOverride); +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideStrategy.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideStrategy.java new file mode 100644 index 00000000000..32bf431495e --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideStrategy.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +/** + * Strategies for override instantiation, implemented in + * {@link BeanOverrideBeanPostProcessor}. + * + * @author Simon Baslé + * @since 6.2 + */ +public enum BeanOverrideStrategy { + + /** + * Replace a given bean's definition, immediately preparing a singleton + * instance. Enforces the original bean definition to exist. + */ + REPLACE_DEFINITION, + /** + * Replace a given bean's definition, immediately preparing a singleton + * instance. If the original bean definition does not exist, create the + * override definition instead of failing. + */ + REPLACE_OR_CREATE_DEFINITION, + /** + * Intercept and wrap the actual bean instance upon creation, during + * {@link org.springframework.beans.factory.config.SmartInstantiationAwareBeanPostProcessor#getEarlyBeanReference(Object, String) + * early bean definition}. + */ + WRAP_EARLY_BEAN; +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideTestExecutionListener.java new file mode 100644 index 00000000000..dc86a686cad --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/BeanOverrideTestExecutionListener.java @@ -0,0 +1,107 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.reflect.Field; +import java.util.function.BiConsumer; + +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; +import org.springframework.test.context.support.AbstractTestExecutionListener; +import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; +import org.springframework.util.ReflectionUtils; + +/** + * A {@link TestExecutionListener} that enables Bean Override support in + * tests, injecting overridden beans in appropriate fields. + * + *

Some flavors of Bean Override might additionally require the use of + * additional listeners, which should be mentioned in the annotation(s) javadoc. + * + * @author Simon Baslé + * @since 6.2 + */ +public class BeanOverrideTestExecutionListener extends AbstractTestExecutionListener { + + /** + * Executes almost last ({@code LOWEST_PRECEDENCE - 50}). + */ + @Override + public int getOrder() { + return LOWEST_PRECEDENCE - 50; + } + + @Override + public void prepareTestInstance(TestContext testContext) throws Exception { + injectFields(testContext); + } + + @Override + public void beforeTestMethod(TestContext testContext) throws Exception { + reinjectFieldsIfConfigured(testContext); + } + + /** + * Using a registered {@link BeanOverrideBeanPostProcessor}, find metadata + * associated with the current test class and ensure fields are injected + * with the overridden bean instance. + */ + protected void injectFields(TestContext testContext) { + postProcessFields(testContext, (testMetadata, postProcessor) -> postProcessor.inject( + testMetadata.overrideMetadata.field(), testMetadata.testInstance(), testMetadata.overrideMetadata())); + } + + /** + * Using a registered {@link BeanOverrideBeanPostProcessor}, find metadata + * associated with the current test class and ensure fields are nulled out + * then re-injected with the overridden bean instance. This method does + * nothing if the {@link DependencyInjectionTestExecutionListener#REINJECT_DEPENDENCIES_ATTRIBUTE} + * attribute is not present in the {@code testContext}. + */ + protected void reinjectFieldsIfConfigured(final TestContext testContext) throws Exception { + if (Boolean.TRUE.equals( + testContext.getAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE))) { + postProcessFields(testContext, (testMetadata, postProcessor) -> { + Field f = testMetadata.overrideMetadata.field(); + ReflectionUtils.makeAccessible(f); + ReflectionUtils.setField(f, testMetadata.testInstance(), null); + postProcessor.inject(f, testMetadata.testInstance(), testMetadata.overrideMetadata()); + }); + } + } + + private void postProcessFields(TestContext testContext, BiConsumer consumer) { + //avoid full parsing but validate that this particular class has some bean override field(s) + BeanOverrideParser parser = new BeanOverrideParser(); + if (parser.hasBeanOverride(testContext.getTestClass())) { + BeanOverrideBeanPostProcessor postProcessor = testContext.getApplicationContext() + .getBean(BeanOverrideBeanPostProcessor.class); + // the class should have already been parsed by the context customizer + for (OverrideMetadata metadata: postProcessor.getOverrideMetadata()) { + if (!metadata.field().getDeclaringClass().equals(testContext.getTestClass())) { + continue; + } + consumer.accept(new TestContextOverrideMetadata(testContext.getTestInstance(), metadata), + postProcessor); + } + } + } + + private record TestContextOverrideMetadata(Object testInstance, OverrideMetadata overrideMetadata) {} + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/OverrideMetadata.java b/spring-test/src/main/java/org/springframework/test/bean/override/OverrideMetadata.java new file mode 100644 index 00000000000..4261441bef6 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/OverrideMetadata.java @@ -0,0 +1,153 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.util.Objects; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.SingletonBeanRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; + +/** + * Metadata for Bean Overrides. + * + * @author Simon Baslé + * @since 6.2 + */ +public abstract class OverrideMetadata { + + private final Field field; + private final Annotation overrideAnnotation; + private final ResolvableType typeToOverride; + private final BeanOverrideStrategy strategy; + + public OverrideMetadata(Field field, Annotation overrideAnnotation, + ResolvableType typeToOverride, BeanOverrideStrategy strategy) { + this.field = field; + this.overrideAnnotation = overrideAnnotation; + this.typeToOverride = typeToOverride; + this.strategy = strategy; + } + + /** + * Define a short human-readable description of the kind of override this + * OverrideMetadata is about. This is especially useful for + * {@link BeanOverrideProcessor} that produce several subtypes of metadata + * (e.g. "mock" vs "spy"). + */ + public abstract String getBeanOverrideDescription(); + + /** + * Provide the expected bean name to override. Typically, this is either + * explicitly set in the concrete annotations or defined by the annotated + * field's name. + * @return the expected bean name, not null + */ + protected String getExpectedBeanName() { + return this.field.getName(); + } + + /** + * The field annotated with a {@link BeanOverride}-compatible annotation. + * @return the annotated field + */ + public Field field() { + return this.field; + } + + /** + * The concrete override annotation, i.e. the one meta-annotated with + * {@link BeanOverride}. + */ + public Annotation overrideAnnotation() { + return this.overrideAnnotation; + } + + /** + * The type to override, as a {@link ResolvableType}. + */ + public ResolvableType typeToOverride() { + return this.typeToOverride; + } + + /** + * Define the broad {@link BeanOverrideStrategy} for this + * {@link OverrideMetadata}, as a hint on how and when the override instance + * should be created. + */ + public final BeanOverrideStrategy getBeanOverrideStrategy() { + return this.strategy; + } + + /** + * Create an override instance from this {@link OverrideMetadata}, + * optionally provided with an existing {@link BeanDefinition} and/or an + * original instance (i.e. a singleton or an early wrapped instance). + * @param beanName the name of the bean being overridden + * @param existingBeanDefinition an existing bean definition for that bean + * name, or {@code null} if not relevant + * @param existingBeanInstance an existing instance for that bean name, + * for wrapping purpose, or {@code null} if irrelevant + * @return the instance with which to override the bean + */ + protected abstract Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, + @Nullable Object existingBeanInstance); + + /** + * Optionally track objects created by this {@link OverrideMetadata} + * (default is no tracking). + * @param override the bean override instance to track + * @param trackingBeanRegistry the registry in which trackers could + * optionally be registered + */ + protected void track(Object override, SingletonBeanRegistry trackingBeanRegistry) { + //NO-OP + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || !getClass().isAssignableFrom(obj.getClass())) { + return false; + } + var that = (OverrideMetadata) obj; + return Objects.equals(this.field, that.field) && + Objects.equals(this.overrideAnnotation, that.overrideAnnotation) && + Objects.equals(this.strategy, that.strategy) && + Objects.equals(this.typeToOverride, that.typeToOverride); + } + + @Override + public int hashCode() { + return Objects.hash(this.field, this.overrideAnnotation, this.strategy, this.typeToOverride); + } + + @Override + public String toString() { + return "OverrideMetadata[" + + "category=" + this.getBeanOverrideDescription() + ", " + + "field=" + this.field + ", " + + "overrideAnnotation=" + this.overrideAnnotation + ", " + + "strategy=" + this.strategy + ", " + + "typeToOverride=" + this.typeToOverride + ']'; + } +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBean.java b/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBean.java new file mode 100644 index 00000000000..d5d74214138 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBean.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.convention; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.test.bean.override.BeanOverride; + +/** + * Mark a field to represent a "method" bean override of the bean of the same + * name and inject the field with the overriding instance. + * + *

The instance is created from a static method in the declaring class which + * return type is compatible with the annotated field and which name follows the + * convention: + *

+ * + *

The annotated field's name is interpreted to be the name of the original + * bean to override, unless the annotation's {@link #name()} is specified. + * + * @see TestBeanOverrideProcessor + * @author Simon Baslé + * @since 6.2 + */ +@Target(ElementType.FIELD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@BeanOverride(TestBeanOverrideProcessor.class) +public @interface TestBean { + + /** + * The method suffix expected as a convention in static methods which + * provides an override instance. + */ + String CONVENTION_SUFFIX = "TestOverride"; + + /** + * The name of a static method to look for in the Configuration, which will + * be used to instantiate the override bean and inject the annotated field. + *

Default is {@code ""} (the empty String), which is translated into + * the annotated field's name concatenated with the + * {@link #CONVENTION_SUFFIX}. + */ + String methodName() default ""; + + /** + * The name of the original bean to override, or {@code ""} (the empty + * String) to deduce the name from the annotated field. + */ + String name() default ""; +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessor.java b/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessor.java new file mode 100644 index 00000000000..20eb05fb166 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessor.java @@ -0,0 +1,145 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.convention; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.test.bean.override.BeanOverrideProcessor; +import org.springframework.test.bean.override.BeanOverrideStrategy; +import org.springframework.test.bean.override.OverrideMetadata; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Simple {@link BeanOverrideProcessor} primarily made to work with the + * {@link TestBean} annotation but can work with arbitrary override annotations + * provided the annotated class has a relevant method according to the + * convention documented in {@link TestBean}. + * + * @author Simon Baslé + * @since 6.2 + */ +public class TestBeanOverrideProcessor implements BeanOverrideProcessor { + + /** + * Ensures the {@code enclosingClass} has a static, no-arguments method with + * the provided {@code expectedMethodReturnType} and exactly one of the + * {@code expectedMethodNames}. + */ + public static Method ensureMethod(Class enclosingClass, Class expectedMethodReturnType, + String... expectedMethodNames) { + Assert.isTrue(expectedMethodNames.length > 0, "At least one expectedMethodName is required"); + Set expectedNames = new LinkedHashSet<>(Arrays.asList(expectedMethodNames)); + final List found = Arrays.stream(enclosingClass.getDeclaredMethods()) + .filter(m -> Modifier.isStatic(m.getModifiers())) + .filter(m -> expectedNames.contains(m.getName()) && expectedMethodReturnType + .isAssignableFrom(m.getReturnType())) + .collect(Collectors.toList()); + + Assert.state(found.size() == 1, () -> "Found " + found.size() + " static methods " + + "instead of exactly one, matching a name in " + expectedNames + " with return type " + + expectedMethodReturnType.getName() + " on class " + enclosingClass.getName()); + + return found.get(0); + } + + @Override + public OverrideMetadata createMetadata(Field field, Annotation overrideAnnotation, ResolvableType typeToOverride) { + final Class enclosingClass = field.getDeclaringClass(); + // if we can get an explicit method name right away, fail fast if it doesn't match + if (overrideAnnotation instanceof TestBean testBeanAnnotation) { + Method overrideMethod = null; + String beanName = null; + if (!testBeanAnnotation.methodName().isBlank()) { + overrideMethod = ensureMethod(enclosingClass, field.getType(), testBeanAnnotation.methodName()); + } + if (!testBeanAnnotation.name().isBlank()) { + beanName = testBeanAnnotation.name(); + } + return new MethodConventionOverrideMetadata(field, overrideMethod, beanName, + overrideAnnotation, typeToOverride); + } + // otherwise defer the resolution of the static method until OverrideMetadata#createOverride + return new MethodConventionOverrideMetadata(field, null, null, overrideAnnotation, + typeToOverride); + } + + static final class MethodConventionOverrideMetadata extends OverrideMetadata { + + @Nullable + private final Method overrideMethod; + + @Nullable + private final String beanName; + + public MethodConventionOverrideMetadata(Field field, @Nullable Method overrideMethod, @Nullable String beanName, + Annotation overrideAnnotation, ResolvableType typeToOverride) { + super(field, overrideAnnotation, typeToOverride, BeanOverrideStrategy.REPLACE_DEFINITION); + this.overrideMethod = overrideMethod; + this.beanName = beanName; + } + + @Override + protected String getExpectedBeanName() { + if (StringUtils.hasText(this.beanName)) { + return this.beanName; + } + return super.getExpectedBeanName(); + } + + @Override + public String getBeanOverrideDescription() { + return "method convention"; + } + + @Override + protected Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, + @Nullable Object existingBeanInstance) { + Method methodToInvoke = this.overrideMethod; + if (methodToInvoke == null) { + methodToInvoke = ensureMethod(field().getDeclaringClass(), field().getType(), + beanName + TestBean.CONVENTION_SUFFIX, + field().getName() + TestBean.CONVENTION_SUFFIX); + } + + methodToInvoke.setAccessible(true); + Object override; + try { + override = methodToInvoke.invoke(null); + } + catch (IllegalAccessException | InvocationTargetException ex) { + throw new IllegalArgumentException("Could not invoke bean overriding method " + methodToInvoke.getName() + + ", a static method with no input parameters is expected", ex); + } + + return override; + } + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/convention/package-info.java b/spring-test/src/main/java/org/springframework/test/bean/override/convention/package-info.java new file mode 100644 index 00000000000..2173d679955 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/convention/package-info.java @@ -0,0 +1,11 @@ +/** + * Bean override mechanism based on conventionally-named static methods + * in the test class. This allows defining a custom instance for the bean + * straight from the test class. + */ +@NonNullApi +@NonNullFields +package org.springframework.test.bean.override.convention; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/Definition.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/Definition.java new file mode 100644 index 00000000000..57ad9c26e83 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/Definition.java @@ -0,0 +1,118 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; + +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.config.SingletonBeanRegistry; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.test.bean.override.BeanOverrideStrategy; +import org.springframework.test.bean.override.OverrideMetadata; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +/** + * Base class for {@link MockDefinition} and {@link SpyDefinition}. + * + * @author Phillip Webb + */ +abstract class Definition extends OverrideMetadata { + + static final int MULTIPLIER = 31; + + protected final String name; + + private final MockReset reset; + + private final boolean proxyTargetAware; + + Definition(String name, @Nullable MockReset reset, boolean proxyTargetAware, Field field, + Annotation annotation, ResolvableType typeToOverride, BeanOverrideStrategy strategy) { + super(field, annotation, typeToOverride, strategy); + this.name = name; + this.reset = (reset != null) ? reset : MockReset.AFTER; + this.proxyTargetAware = proxyTargetAware; + } + + @Override + protected String getExpectedBeanName() { + if (StringUtils.hasText(this.name)) { + return this.name; + } + return super.getExpectedBeanName(); + } + + @Override + protected void track(Object mock, SingletonBeanRegistry trackingBeanRegistry) { + MockitoBeans tracker = null; + try { + tracker = (MockitoBeans) trackingBeanRegistry.getSingleton(MockitoBeans.class.getName()); + } + catch (NoSuchBeanDefinitionException ignored) { + + } + if (tracker == null) { + tracker= new MockitoBeans(); + trackingBeanRegistry.registerSingleton(MockitoBeans.class.getName(), tracker); + } + tracker.add(mock); + } + + /** + * Return the mock reset mode. + * @return the reset mode + */ + MockReset getReset() { + return this.reset; + } + + /** + * Return if AOP advised beans should be proxy target aware. + * @return if proxy target aware + */ + boolean isProxyTargetAware() { + return this.proxyTargetAware; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj == null || !getClass().isAssignableFrom(obj.getClass())) { + return false; + } + Definition other = (Definition) obj; + boolean result = ObjectUtils.nullSafeEquals(this.name, other.name); + result = result && ObjectUtils.nullSafeEquals(this.reset, other.reset); + result = result && ObjectUtils.nullSafeEquals(this.proxyTargetAware, other.proxyTargetAware); + return result; + } + + @Override + public int hashCode() { + int result = 1; + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.name); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.reset); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.proxyTargetAware); + return result; + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockDefinition.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockDefinition.java new file mode 100644 index 00000000000..05fea839495 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockDefinition.java @@ -0,0 +1,170 @@ +/* + * Copyright 2012-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.mockito.Answers; +import org.mockito.MockSettings; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.core.style.ToStringCreator; +import org.springframework.lang.Nullable; +import org.springframework.test.bean.override.BeanOverrideStrategy; +import org.springframework.util.Assert; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +import static org.mockito.Mockito.mock; + +/** + * A complete definition that can be used to create a Mockito mock. + * + * @author Phillip Webb + */ +class MockDefinition extends Definition { + + private static final int MULTIPLIER = 31; + + private final Set> extraInterfaces; + + private final Answers answer; + + private final boolean serializable; + + MockDefinition(MockitoBean annotation, Field field, ResolvableType typeToMock) { + this(annotation.name(), annotation.reset(), field, annotation, typeToMock, + annotation.extraInterfaces(), annotation.answers(), annotation.serializable()); + } + + MockDefinition(String name, MockReset reset, Field field, Annotation annotation, ResolvableType typeToMock, + Class[] extraInterfaces, @Nullable Answers answer, boolean serializable) { + super(name, reset, false, field, annotation, typeToMock, BeanOverrideStrategy.REPLACE_OR_CREATE_DEFINITION); + Assert.notNull(typeToMock, "TypeToMock must not be null"); + this.extraInterfaces = asClassSet(extraInterfaces); + this.answer = (answer != null) ? answer : Answers.RETURNS_DEFAULTS; + this.serializable = serializable; + } + + @Override + public String getBeanOverrideDescription() { + return "mock"; + } + + @Override + protected Object createOverride(String beanName, BeanDefinition existingBeanDefinition, Object existingBeanInstance) { + return createMock(beanName); + } + + private Set> asClassSet(Class[] classes) { + Set> classSet = new LinkedHashSet<>(); + if (classes != null) { + classSet.addAll(Arrays.asList(classes)); + } + return Collections.unmodifiableSet(classSet); + } + + /** + * Return the extra interfaces. + * @return the extra interfaces or an empty set + */ + Set> getExtraInterfaces() { + return this.extraInterfaces; + } + + /** + * Return the answers mode. + * @return the answers mode; never {@code null} + */ + Answers getAnswer() { + return this.answer; + } + + /** + * Return if the mock is serializable. + * @return if the mock is serializable + */ + boolean isSerializable() { + return this.serializable; + } + + @Override + public boolean equals(@Nullable Object obj) { + if (obj == this) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + MockDefinition other = (MockDefinition) obj; + boolean result = super.equals(obj); + result = result && ObjectUtils.nullSafeEquals(this.typeToOverride(), other.typeToOverride()); + result = result && ObjectUtils.nullSafeEquals(this.extraInterfaces, other.extraInterfaces); + result = result && ObjectUtils.nullSafeEquals(this.answer, other.answer); + result = result && this.serializable == other.serializable; + return result; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.typeToOverride()); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.extraInterfaces); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.answer); + result = MULTIPLIER * result + Boolean.hashCode(this.serializable); + return result; + } + + @Override + public String toString() { + return new ToStringCreator(this).append("name", this.name) + .append("typeToMock", this.typeToOverride()) + .append("extraInterfaces", this.extraInterfaces) + .append("answer", this.answer) + .append("serializable", this.serializable) + .append("reset", getReset()) + .toString(); + } + + T createMock() { + return createMock(this.name); + } + + @SuppressWarnings("unchecked") + T createMock(String name) { + MockSettings settings = MockReset.withSettings(getReset()); + if (StringUtils.hasLength(name)) { + settings.name(name); + } + if (!this.extraInterfaces.isEmpty()) { + settings.extraInterfaces(ClassUtils.toClassArray(this.extraInterfaces)); + } + settings.defaultAnswer(this.answer); + if (this.serializable) { + settings.serializable(); + } + return (T) mock(this.typeToOverride().resolve(), settings); + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockReset.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockReset.java new file mode 100644 index 00000000000..b2184b1381e --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockReset.java @@ -0,0 +1,139 @@ +/* + * Copyright 2012-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.util.List; + +import org.mockito.MockSettings; +import org.mockito.MockingDetails; +import org.mockito.Mockito; +import org.mockito.listeners.InvocationListener; +import org.mockito.listeners.MethodInvocationReport; +import org.mockito.mock.MockCreationSettings; + +import org.springframework.util.Assert; + +/** + * Reset strategy used on a mock bean. Usually applied to a mock through the + * {@link MockitoBean @MockitoBean} annotation but can also be directly applied to any mock in + * the {@code ApplicationContext} using the static methods. + * + * @author Phillip Webb + * @since 1.4.0 + * @see MockitoResetTestExecutionListener + */ +public enum MockReset { + + /** + * Reset the mock before the test method runs. + */ + BEFORE, + + /** + * Reset the mock after the test method runs. + */ + AFTER, + + /** + * Don't reset the mock. + */ + NONE; + + /** + * Create {@link MockSettings settings} to be used with mocks where reset should occur + * before each test method runs. + * @return mock settings + */ + public static MockSettings before() { + return withSettings(BEFORE); + } + + /** + * Create {@link MockSettings settings} to be used with mocks where reset should occur + * after each test method runs. + * @return mock settings + */ + public static MockSettings after() { + return withSettings(AFTER); + } + + /** + * Create {@link MockSettings settings} to be used with mocks where a specific reset + * should occur. + * @param reset the reset type + * @return mock settings + */ + public static MockSettings withSettings(MockReset reset) { + return apply(reset, Mockito.withSettings()); + } + + /** + * Apply {@link MockReset} to existing {@link MockSettings settings}. + * @param reset the reset type + * @param settings the settings + * @return the configured settings + */ + public static MockSettings apply(MockReset reset, MockSettings settings) { + Assert.notNull(settings, "Settings must not be null"); + if (reset != null && reset != NONE) { + settings.invocationListeners(new ResetInvocationListener(reset)); + } + return settings; + } + + /** + * Get the {@link MockReset} associated with the given mock. + * @param mock the source mock + * @return the reset type (never {@code null}) + */ + static MockReset get(Object mock) { + MockReset reset = MockReset.NONE; + MockingDetails mockingDetails = Mockito.mockingDetails(mock); + if (mockingDetails.isMock()) { + MockCreationSettings settings = mockingDetails.getMockCreationSettings(); + List listeners = settings.getInvocationListeners(); + for (Object listener : listeners) { + if (listener instanceof ResetInvocationListener resetInvocationListener) { + reset = resetInvocationListener.getReset(); + } + } + } + return reset; + } + + /** + * Dummy {@link InvocationListener} used to hold the {@link MockReset} value. + */ + private static class ResetInvocationListener implements InvocationListener { + + private final MockReset reset; + + ResetInvocationListener(MockReset reset) { + this.reset = reset; + } + + MockReset getReset() { + return this.reset; + } + + @Override + public void reportInvocation(MethodInvocationReport methodInvocationReport) { + } + + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBean.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBean.java new file mode 100644 index 00000000000..ec33e57ec68 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBean.java @@ -0,0 +1,86 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.mockito.Answers; +import org.mockito.MockSettings; + +import org.springframework.test.bean.override.BeanOverride; + +/** + * Mark a field to trigger a bean override using a Mockito mock. If no explicit + * {@link #name()} is specified, the annotated field's name is interpreted to + * be the target of the override. In either case, if no existing bean is defined + * a new one will be added to the context. In order to ensure mocks are set up + * and reset correctly, the test class must itself be annotated with + * {@link MockitoBeanOverrideTestListeners}. + * + *

Dependencies that are known to the application context but are not beans + * (such as those {@link org.springframework.beans.factory.config.ConfigurableListableBeanFactory#registerResolvableDependency(Class, Object) + * registered directly}) will not be found and a mocked bean will be added to + * the context alongside the existing dependency. + * + * @author Simon Baslé + * @since 6.2 + */ +@Target(ElementType.FIELD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@BeanOverride(MockitoBeanOverrideProcessor.class) +public @interface MockitoBean { + + /** + * The name of the bean to register or replace. If not specified, it will be + * the name of the annotated field. + * @return the name of the bean + */ + String name() default ""; + + /** + * Any extra interfaces that should also be declared on the mock. See + * {@link MockSettings#extraInterfaces(Class...)} for details. + * @return any extra interfaces + */ + Class[] extraInterfaces() default {}; + + /** + * The {@link Answers} type to use on the mock. + * @return the answer type + */ + Answers answers() default Answers.RETURNS_DEFAULTS; + + /** + * If the generated mock is serializable. See {@link MockSettings#serializable()} for + * details. + * @return if the mock is serializable + */ + boolean serializable() default false; + + /** + * The reset mode to apply to the mock bean. The default is {@link MockReset#AFTER} + * meaning that mocks are automatically reset after each test method is invoked. + * @return the reset mode + */ + MockReset reset() default MockReset.AFTER; + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeanOverrideProcessor.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeanOverrideProcessor.java new file mode 100644 index 00000000000..d74b132122c --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeanOverrideProcessor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; + +import org.springframework.core.ResolvableType; +import org.springframework.test.bean.override.BeanOverrideProcessor; +import org.springframework.test.bean.override.OverrideMetadata; + +public class MockitoBeanOverrideProcessor implements BeanOverrideProcessor { + + public OverrideMetadata createMetadata(Field field, Annotation overrideAnnotation, ResolvableType typeToMock) { + if (overrideAnnotation instanceof MockitoBean mockBean) { + return new MockDefinition(mockBean, field, typeToMock); + } + else if (overrideAnnotation instanceof MockitoSpyBean spyBean) { + return new SpyDefinition(spyBean, field, typeToMock); + } + throw new IllegalArgumentException("Invalid annotation for MockitoBeanOverrideProcessor: " + overrideAnnotation.getClass().getName()); + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeans.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeans.java new file mode 100644 index 00000000000..9431b4e872d --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoBeans.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +/** + * Beans created using Mockito. + * + * @author Andy Wilkinson + */ +class MockitoBeans implements Iterable { + + private final List beans = new ArrayList<>(); + + void add(Object bean) { + this.beans.add(bean); + } + + @Override + public Iterator iterator() { + return this.beans.iterator(); + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoResetTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoResetTestExecutionListener.java new file mode 100644 index 00000000000..e5983760789 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoResetTestExecutionListener.java @@ -0,0 +1,126 @@ +/* + * Copyright 2012-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.mockito.Mockito; + +import org.springframework.beans.factory.BeanFactory; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.NoSuchBeanDefinitionException; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.core.NativeDetector; +import org.springframework.core.Ordered; +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; +import org.springframework.test.context.support.AbstractTestExecutionListener; + +/** + * {@link TestExecutionListener} to reset any mock beans that have been marked with a + * {@link MockReset}. Typically used alongside {@link MockitoTestExecutionListener}. + * + * @author Phillip Webb + * @since 6.2 + * @see MockitoTestExecutionListener + */ +public class MockitoResetTestExecutionListener extends AbstractTestExecutionListener { + + /** + * Executes before {@link org.springframework.test.bean.override.BeanOverrideTestExecutionListener}. + */ + @Override + public int getOrder() { + return Ordered.LOWEST_PRECEDENCE - 100; + } + + @Override + public void beforeTestMethod(TestContext testContext) throws Exception { + if (MockitoTestExecutionListener.mockitoPresent && !NativeDetector.inNativeImage()) { + resetMocks(testContext.getApplicationContext(), MockReset.BEFORE); + } + } + + @Override + public void afterTestMethod(TestContext testContext) throws Exception { + if (MockitoTestExecutionListener.mockitoPresent && !NativeDetector.inNativeImage()) { + resetMocks(testContext.getApplicationContext(), MockReset.AFTER); + } + } + + private void resetMocks(ApplicationContext applicationContext, MockReset reset) { + if (applicationContext instanceof ConfigurableApplicationContext configurableContext) { + resetMocks(configurableContext, reset); + } + } + + private void resetMocks(ConfigurableApplicationContext applicationContext, MockReset reset) { + ConfigurableListableBeanFactory beanFactory = applicationContext.getBeanFactory(); + String[] names = beanFactory.getBeanDefinitionNames(); + Set instantiatedSingletons = new HashSet<>(Arrays.asList(beanFactory.getSingletonNames())); + for (String name : names) { + BeanDefinition definition = beanFactory.getBeanDefinition(name); + if (definition.isSingleton() && instantiatedSingletons.contains(name)) { + Object bean = getBean(beanFactory, name); + if (bean != null && reset.equals(MockReset.get(bean))) { + Mockito.reset(bean); + } + } + } + try { + MockitoBeans mockedBeans = beanFactory.getBean(MockitoBeans.class); + for (Object mockedBean : mockedBeans) { + if (reset.equals(MockReset.get(mockedBean))) { + Mockito.reset(mockedBean); + } + } + } + catch (NoSuchBeanDefinitionException ex) { + // Continue + } + if (applicationContext.getParent() != null) { + resetMocks(applicationContext.getParent(), reset); + } + } + + private Object getBean(ConfigurableListableBeanFactory beanFactory, String name) { + try { + if (isStandardBeanOrSingletonFactoryBean(beanFactory, name)) { + return beanFactory.getBean(name); + } + } + catch (Exception ex) { + // Continue + } + return beanFactory.getSingleton(name); + } + + private boolean isStandardBeanOrSingletonFactoryBean(ConfigurableListableBeanFactory beanFactory, String name) { + String factoryBeanName = BeanFactory.FACTORY_BEAN_PREFIX + name; + if (beanFactory.containsBean(factoryBeanName)) { + FactoryBean factoryBean = (FactoryBean) beanFactory.getBean(factoryBeanName); + return factoryBean.isSingleton(); + } + return true; + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoSpyBean.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoSpyBean.java new file mode 100644 index 00000000000..360d2c22cf4 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoSpyBean.java @@ -0,0 +1,84 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.mockito.Mockito; + +import org.springframework.test.bean.override.BeanOverride; + +/** + * Mark a field to trigger the override of the bean of the same name with a + * Mockito spy, which will wrap the original instance. + * In order to ensure mocks are set up and reset correctly, the test class must + * itself be annotated with {@link MockitoBeanOverrideTestListeners}. + * + * @author Simon Baslé + * @since 6.2 + */ +/** + * Mark a field to trigger a bean override using a Mockito spy, which will wrap + * the original instance. If no explicit {@link #name()} is specified, the + * annotated field's name is interpreted to be the target of the override. + * In either case, it is required that the target bean is previously registered + * in the context. In order to ensure spies are set up and reset correctly, + * the test class must itself be annotated with {@link MockitoBeanOverrideTestListeners}. + * + *

Dependencies that are known to the application context but are not beans + * (such as those {@link org.springframework.beans.factory.config.ConfigurableListableBeanFactory#registerResolvableDependency(Class, Object) + * registered directly}) will not be found. + * + * @author Simon Baslé + * @since 6.2 + */ +@Target(ElementType.FIELD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@BeanOverride(MockitoBeanOverrideProcessor.class) +public @interface MockitoSpyBean { + + /** + * The name of the bean to spy. If not specified, it will be the name of the + * annotated field. + * @return the name of the spied bean + */ + String name() default ""; + + /** + * The reset mode to apply to the spied bean. The default is {@link MockReset#AFTER} + * meaning that spies are automatically reset after each test method is invoked. + * @return the reset mode + */ + MockReset reset() default MockReset.AFTER; + + /** + * Indicates that Mockito methods such as {@link Mockito#verify(Object) verify(mock)} + * should use the {@code target} of AOP advised beans, rather than the proxy itself. + * If set to {@code false} you may need to use the result of + * {@link org.springframework.test.util.AopTestUtils#getUltimateTargetObject(Object) + * AopTestUtils.getUltimateTargetObject(...)} when calling Mockito methods. + * @return {@code true} if the target of AOP advised beans is used or {@code false} if + * the proxy is used directly + */ + boolean proxyTargetAware() default true; + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoTestExecutionListener.java new file mode 100644 index 00000000000..2252ecfe87c --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/MockitoTestExecutionListener.java @@ -0,0 +1,139 @@ +/* + * Copyright 2012-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.util.LinkedHashSet; +import java.util.Set; + +import org.mockito.Captor; +import org.mockito.MockitoAnnotations; + +import org.springframework.test.context.TestContext; +import org.springframework.test.context.TestExecutionListener; +import org.springframework.test.context.support.AbstractTestExecutionListener; +import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; +import org.springframework.util.ClassUtils; +import org.springframework.util.ReflectionUtils; +import org.springframework.util.ReflectionUtils.FieldCallback; + +/** + * {@link TestExecutionListener} to enable {@link MockitoBean @MockitoBean} and + * {@link MockitoSpyBean @MockitoSpyBean} support. Also triggers + * {@link MockitoAnnotations#openMocks(Object)} when any Mockito annotations used, + * primarily to allow {@link Captor @Captor} annotations. + *

+ * The automatic reset support of {@code @MockBean} and {@code @SpyBean} is + * handled by sibling {@link MockitoResetTestExecutionListener}. + * + * @author Simon Baslé + * @author Phillip Webb + * @author Andy Wilkinson + * @author Moritz Halbritter + * @since 1.4.2 + * @see MockitoResetTestExecutionListener + */ +public class MockitoTestExecutionListener extends AbstractTestExecutionListener { + + static final boolean mockitoPresent = ClassUtils.isPresent("org.mockito.MockSettings", + MockitoTestExecutionListener.class.getClassLoader()); + + private static final String MOCKS_ATTRIBUTE_NAME = MockitoTestExecutionListener.class.getName() + ".mocks"; + + /** + * Executes before {@link DependencyInjectionTestExecutionListener}. + */ + @Override + public final int getOrder() { + return 1950; + } + + @Override + public void prepareTestInstance(TestContext testContext) throws Exception { + if (mockitoPresent) { + closeMocks(testContext); + initMocks(testContext); + } + } + + @Override + public void beforeTestMethod(TestContext testContext) throws Exception { + if (mockitoPresent && Boolean.TRUE.equals( + testContext.getAttribute(DependencyInjectionTestExecutionListener.REINJECT_DEPENDENCIES_ATTRIBUTE))) { + closeMocks(testContext); + initMocks(testContext); + } + } + + @Override + public void afterTestMethod(TestContext testContext) throws Exception { + if (mockitoPresent) { + closeMocks(testContext); + } + } + + @Override + public void afterTestClass(TestContext testContext) throws Exception { + if (mockitoPresent) { + closeMocks(testContext); + } + } + + private void initMocks(TestContext testContext) { + if (hasMockitoAnnotations(testContext)) { + Object testInstance = testContext.getTestInstance(); + testContext.setAttribute(MOCKS_ATTRIBUTE_NAME, MockitoAnnotations.openMocks(testInstance)); + } + } + + private void closeMocks(TestContext testContext) throws Exception { + Object mocks = testContext.getAttribute(MOCKS_ATTRIBUTE_NAME); + if (mocks instanceof AutoCloseable closeable) { + closeable.close(); + } + } + + private boolean hasMockitoAnnotations(TestContext testContext) { + MockitoAnnotationCollection collector = new MockitoAnnotationCollection(); + ReflectionUtils.doWithFields(testContext.getTestClass(), collector); + return collector.hasAnnotations(); + } + + /** + * {@link FieldCallback} to collect Mockito annotations. + */ + private static final class MockitoAnnotationCollection implements FieldCallback { + + private final Set annotations = new LinkedHashSet<>(); + + @Override + public void doWith(Field field) throws IllegalArgumentException { + for (Annotation annotation : field.getDeclaredAnnotations()) { + if (annotation.annotationType().getName().startsWith("org.mockito")) { + this.annotations.add(annotation); + } + } + } + + boolean hasAnnotations() { + return !this.annotations.isEmpty(); + } + + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/SpyDefinition.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/SpyDefinition.java new file mode 100644 index 00000000000..acc2e047151 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/SpyDefinition.java @@ -0,0 +1,145 @@ +/* + * Copyright 2012-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.mockito; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.lang.reflect.Proxy; +import java.util.Objects; + +import org.mockito.AdditionalAnswers; +import org.mockito.MockSettings; +import org.mockito.Mockito; +import org.mockito.listeners.VerificationStartedEvent; +import org.mockito.listeners.VerificationStartedListener; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.core.style.ToStringCreator; +import org.springframework.lang.Nullable; +import org.springframework.test.bean.override.BeanOverrideStrategy; +import org.springframework.test.util.AopTestUtils; +import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; +import org.springframework.util.StringUtils; + +import static org.mockito.Mockito.mock; + +/** + * A complete definition that can be used to create a Mockito spy. + * + * @author Phillip Webb + */ +class SpyDefinition extends Definition { + + SpyDefinition(MockitoSpyBean spyAnnotation, Field field, ResolvableType typeToSpy) { + this(spyAnnotation.name(), spyAnnotation.reset(), spyAnnotation.proxyTargetAware(), field, + spyAnnotation, typeToSpy); + } + + SpyDefinition(String name, MockReset reset, boolean proxyTargetAware, Field field, Annotation annotation, + ResolvableType typeToSpy) { + super(name, reset, proxyTargetAware, field, annotation, typeToSpy, BeanOverrideStrategy.WRAP_EARLY_BEAN); + Assert.notNull(typeToSpy, "typeToSpy must not be null"); + } + + @Override + public String getBeanOverrideDescription() { + return "spy"; + } + + @Override + protected Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, @Nullable Object existingBeanInstance) { + return createSpy(beanName, Objects.requireNonNull(existingBeanInstance, + "MockitoSpyBean requires an existing bean instance for bean " + beanName)); + } + + @Override + public boolean equals(@Nullable Object obj) { + //for SpyBean we want the class to be exactly the same + if (obj == this) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + SpyDefinition other = (SpyDefinition) obj; + boolean result = super.equals(obj); + result = result && ObjectUtils.nullSafeEquals(this.typeToOverride(), other.typeToOverride()); + return result; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = MULTIPLIER * result + ObjectUtils.nullSafeHashCode(this.typeToOverride()); + return result; + } + + @Override + public String toString() { + return new ToStringCreator(this).append("name", this.name) + .append("typeToSpy", typeToOverride()) + .append("reset", getReset()) + .toString(); + } + + T createSpy(Object instance) { + return createSpy(this.name, instance); + } + + @SuppressWarnings("unchecked") + T createSpy(String name, Object instance) { + Assert.notNull(instance, "Instance must not be null"); + Assert.isInstanceOf(Objects.requireNonNull(this.typeToOverride().resolve()), instance); + if (Mockito.mockingDetails(instance).isSpy()) { + return (T) instance; + } + MockSettings settings = MockReset.withSettings(getReset()); + if (StringUtils.hasLength(name)) { + settings.name(name); + } + if (isProxyTargetAware()) { + settings.verificationStartedListeners(new SpringAopBypassingVerificationStartedListener()); + } + Class toSpy; + if (Proxy.isProxyClass(instance.getClass())) { + settings.defaultAnswer(AdditionalAnswers.delegatesTo(instance)); + toSpy = this.typeToOverride().toClass(); + } + else { + settings.defaultAnswer(Mockito.CALLS_REAL_METHODS); + settings.spiedInstance(instance); + toSpy = instance.getClass(); + } + return (T) mock(toSpy, settings); + } + + /** + * A {@link VerificationStartedListener} that bypasses any proxy created by Spring AOP + * when the verification of a spy starts. + */ + private static final class SpringAopBypassingVerificationStartedListener implements VerificationStartedListener { + + @Override + public void onVerificationStarted(VerificationStartedEvent event) { + event.setMock(AopTestUtils.getUltimateTargetObject(event.getMock())); + } + + } + +} diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/mockito/package-info.java b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/package-info.java new file mode 100644 index 00000000000..4072e97cd71 --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/mockito/package-info.java @@ -0,0 +1,9 @@ +/** + * Support case-by-case Bean overriding in Spring tests. + */ +@NonNullApi +@NonNullFields +package org.springframework.test.bean.override.mockito; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-test/src/main/java/org/springframework/test/bean/override/package-info.java b/spring-test/src/main/java/org/springframework/test/bean/override/package-info.java new file mode 100644 index 00000000000..567521dac4a --- /dev/null +++ b/spring-test/src/main/java/org/springframework/test/bean/override/package-info.java @@ -0,0 +1,9 @@ +/** + * Support case-by-case Bean overriding in Spring tests. + */ +@NonNullApi +@NonNullFields +package org.springframework.test.bean.override; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-test/src/main/resources/META-INF/spring.factories b/spring-test/src/main/resources/META-INF/spring.factories index 2b9e4e3c116..ad629a95e92 100644 --- a/spring-test/src/main/resources/META-INF/spring.factories +++ b/spring-test/src/main/resources/META-INF/spring.factories @@ -1,6 +1,9 @@ # Default TestExecutionListeners for the Spring TestContext Framework # org.springframework.test.context.TestExecutionListener = \ + org.springframework.test.bean.override.BeanOverrideTestExecutionListener,\ + org.springframework.test.bean.override.mockito.MockitoTestExecutionListener,\ + org.springframework.test.bean.override.mockito.MockitoResetTestExecutionListener,\ org.springframework.test.context.web.ServletTestExecutionListener,\ org.springframework.test.context.support.DirtiesContextBeforeModesTestExecutionListener,\ org.springframework.test.context.event.ApplicationEventsTestExecutionListener,\ @@ -14,5 +17,6 @@ org.springframework.test.context.TestExecutionListener = \ # Default ContextCustomizerFactory implementations for the Spring TestContext Framework # org.springframework.test.context.ContextCustomizerFactory = \ + org.springframework.test.bean.override.BeanOverrideContextCustomizerFactory,\ org.springframework.test.context.web.socket.MockServerContainerContextCustomizerFactory,\ org.springframework.test.context.support.DynamicPropertiesContextCustomizerFactory diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessorTests.java b/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessorTests.java new file mode 100644 index 00000000000..0a75cc9abf5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideBeanPostProcessorTests.java @@ -0,0 +1,328 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.util.Map; +import java.util.function.Predicate; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.factory.FactoryBean; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.support.SimpleThreadScope; +import org.springframework.core.Ordered; +import org.springframework.core.ResolvableType; +import org.springframework.test.bean.override.example.ExampleBeanOverrideAnnotation; +import org.springframework.test.bean.override.example.ExampleService; +import org.springframework.test.bean.override.example.FailingExampleService; +import org.springframework.test.bean.override.example.RealExampleService; +import org.springframework.test.util.ReflectionTestUtils; +import org.springframework.util.Assert; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.assertj.core.api.Assertions.assertThatNoException; + +/** + * Test for {@link BeanOverrideBeanPostProcessor}. + * + * @author Simon Baslé + */ +class BeanOverrideBeanPostProcessorTests { + + BeanOverrideParser parser; + + @BeforeEach + void initParser() { + this.parser = new BeanOverrideParser(); + } + + @Test + void canReplaceExistingBeanDefinitions() { + this.parser.parse(ReplaceBeans.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(ReplaceBeans.class); + context.registerBean("explicit", ExampleService.class, () -> new RealExampleService("unexpected")); + context.registerBean("implicitName", ExampleService.class, () -> new RealExampleService("unexpected")); + + context.refresh(); + + assertThat(context.getBean("explicit")).isSameAs(OVERRIDE_SERVICE); + assertThat(context.getBean("implicitName")).isSameAs(OVERRIDE_SERVICE); + } + + @Test + void cannotReplaceIfNoBeanMatching() { + this.parser.parse(ReplaceBeans.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(ReplaceBeans.class); + //note we don't register any original bean here + + assertThatIllegalStateException().isThrownBy(context::refresh).withMessage("Unable to override test bean, " + + "expected a bean definition to replace with name 'explicit'"); + } + + @Test + void canReplaceExistingBeanDefinitionsWithCreateReplaceStrategy() { + this.parser.parse(CreateIfOriginalIsMissingBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(CreateIfOriginalIsMissingBean.class); + context.registerBean("explicit", ExampleService.class, () -> new RealExampleService("unexpected")); + context.registerBean("implicitName", ExampleService.class, () -> new RealExampleService("unexpected")); + + context.refresh(); + + assertThat(context.getBean("explicit")).isSameAs(OVERRIDE_SERVICE); + assertThat(context.getBean("implicitName")).isSameAs(OVERRIDE_SERVICE); + } + + @Test + void canCreateIfOriginalMissingWithCreateReplaceStrategy() { + this.parser.parse(CreateIfOriginalIsMissingBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(CreateIfOriginalIsMissingBean.class); + //note we don't register original beans here + + context.refresh(); + + assertThat(context.getBean("explicit")).isSameAs(OVERRIDE_SERVICE); + assertThat(context.getBean("implicitName")).isSameAs(OVERRIDE_SERVICE); + } + + @Test + void canOverrideBeanProducedByFactoryBeanWithClassObjectTypeAttribute() { + this.parser.parse(OverriddenFactoryBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, parser.getOverrideMetadata()); + RootBeanDefinition factoryBeanDefinition = new RootBeanDefinition(TestFactoryBean.class); + factoryBeanDefinition.setAttribute(FactoryBean.OBJECT_TYPE_ATTRIBUTE, SomeInterface.class); + context.registerBeanDefinition("beanToBeOverridden", factoryBeanDefinition); + context.register(OverriddenFactoryBean.class); + context.refresh(); + assertThat(context.getBean("beanToBeOverridden")).isSameAs(OVERRIDE); + } + + @Test + void canOverrideBeanProducedByFactoryBeanWithResolvableTypeObjectTypeAttribute() { + this.parser.parse(OverriddenFactoryBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + BeanOverrideBeanPostProcessor.register(context, parser.getOverrideMetadata()); + RootBeanDefinition factoryBeanDefinition = new RootBeanDefinition(TestFactoryBean.class); + ResolvableType objectType = ResolvableType.forClass(SomeInterface.class); + factoryBeanDefinition.setAttribute(FactoryBean.OBJECT_TYPE_ATTRIBUTE, objectType); + context.registerBeanDefinition("beanToBeOverridden", factoryBeanDefinition); + context.register(OverriddenFactoryBean.class); + context.refresh(); + assertThat(context.getBean("beanToBeOverridden")).isSameAs(OVERRIDE); + } + + + @Test + void postProcessorShouldNotTriggerEarlyInitialization() { + this.parser.parse(EagerInitBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.register(FactoryBeanRegisteringPostProcessor.class); + BeanOverrideBeanPostProcessor.register(context, parser.getOverrideMetadata()); + context.register(EarlyBeanInitializationDetector.class); + context.register(EagerInitBean.class); + + assertThatNoException().isThrownBy(context::refresh); + } + + @Test + void allowReplaceDefinitionWhenSingletonDefinitionPresent() { + this.parser.parse(SingletonBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + RootBeanDefinition definition = new RootBeanDefinition(String.class, () -> "ORIGINAL"); + definition.setScope(BeanDefinition.SCOPE_SINGLETON); + context.registerBeanDefinition("singleton", definition); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(SingletonBean.class); + + assertThatNoException().isThrownBy(context::refresh); + assertThat(context.isSingleton("singleton")).as("isSingleton").isTrue(); + assertThat(context.getBean("singleton")).as("overridden").isEqualTo("USED THIS"); + } + + @Test + void copyDefinitionPrimaryAndScope() { + this.parser.parse(SingletonBean.class); + AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); + context.getBeanFactory().registerScope("customScope", new SimpleThreadScope()); + RootBeanDefinition definition = new RootBeanDefinition(String.class, () -> "ORIGINAL"); + definition.setScope("customScope"); + definition.setPrimary(true); + context.registerBeanDefinition("singleton", definition); + BeanOverrideBeanPostProcessor.register(context, this.parser.getOverrideMetadata()); + context.register(SingletonBean.class); + + assertThatNoException().isThrownBy(context::refresh); + assertThat(context.getBeanDefinition("singleton")) + .isNotSameAs(definition) + .matches(BeanDefinition::isPrimary, "isPrimary") + .satisfies(d -> assertThat(d.getScope()).isEqualTo("customScope")) + .matches(Predicate.not(BeanDefinition::isSingleton), "!isSingleton") + .matches(Predicate.not(BeanDefinition::isPrototype), "!isPrototype"); + } + + /* + Classes to parse and register with the bean post processor + ----- + Note that some of these are both a @Configuration class and bean override field holder. + This is for this test convenience, as typically the bean override annotated fields + should not be in configuration classes but rather in test case classes + (where a TestExecutionListener automatically discovers and parses them). + */ + + static final SomeInterface OVERRIDE = new SomeImplementation(); + static final ExampleService OVERRIDE_SERVICE = new FailingExampleService(); + + static class ReplaceBeans { + + @ExampleBeanOverrideAnnotation(value = "useThis", beanName = "explicit") + private ExampleService explicitName; + + @ExampleBeanOverrideAnnotation(value = "useThis") + private ExampleService implicitName; + + static ExampleService useThis() { + return OVERRIDE_SERVICE; + } + } + + static class CreateIfOriginalIsMissingBean { + + @ExampleBeanOverrideAnnotation(value = "useThis", createIfMissing = true, beanName = "explicit") + private ExampleService explicitName; + + @ExampleBeanOverrideAnnotation(value = "useThis", createIfMissing = true) + private ExampleService implicitName; + + static ExampleService useThis() { + return OVERRIDE_SERVICE; + } + + } + + @Configuration(proxyBeanMethods = false) + static class OverriddenFactoryBean { + + @ExampleBeanOverrideAnnotation(value = "fOverride", beanName = "beanToBeOverridden") + SomeInterface f; + + static SomeInterface fOverride() { + return OVERRIDE; + } + + @Bean + TestFactoryBean testFactoryBean() { + return new TestFactoryBean(); + } + + } + + static class EagerInitBean { + + @ExampleBeanOverrideAnnotation(value = "useThis", createIfMissing = true) + private ExampleService service; + + static ExampleService useThis() { + return OVERRIDE_SERVICE; + } + + } + + static class SingletonBean { + + @ExampleBeanOverrideAnnotation(beanName = "singleton", + value = "useThis", createIfMissing = false) + private String value; + + static String useThis() { + return "USED THIS"; + } + + } + + static class TestFactoryBean implements FactoryBean { + + @Override + public Object getObject() { + return new SomeImplementation(); + } + + @Override + public Class getObjectType() { + return null; + } + + @Override + public boolean isSingleton() { + return true; + } + + } + + static class FactoryBeanRegisteringPostProcessor implements BeanFactoryPostProcessor, Ordered { + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { + RootBeanDefinition beanDefinition = new RootBeanDefinition(TestFactoryBean.class); + ((BeanDefinitionRegistry) beanFactory).registerBeanDefinition("test", beanDefinition); + } + + @Override + public int getOrder() { + return Ordered.HIGHEST_PRECEDENCE; + } + + } + + static class EarlyBeanInitializationDetector implements BeanFactoryPostProcessor { + + @Override + @SuppressWarnings("unchecked") + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) { + Map cache = (Map) ReflectionTestUtils.getField(beanFactory, + "factoryBeanInstanceCache"); + Assert.isTrue(cache.isEmpty(), "Early initialization of factory bean triggered."); + } + + } + + interface SomeInterface { + + } + + static class SomeImplementation implements SomeInterface { + + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideParserTests.java b/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideParserTests.java new file mode 100644 index 00000000000..ffa913510b0 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/BeanOverrideParserTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Configuration; +import org.springframework.test.bean.override.example.ExampleBeanOverrideAnnotation; +import org.springframework.test.bean.override.example.TestBeanOverrideMetaAnnotation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatRuntimeException; +import static org.springframework.test.bean.override.example.ExampleBeanOverrideProcessor.DUPLICATE_TRIGGER; + +class BeanOverrideParserTests { + + @Test + void findsOnField() { + BeanOverrideParser parser = new BeanOverrideParser(); + parser.parse(OnFieldConf.class); + + assertThat(parser.getOverrideMetadata()).hasSize(1) + .first() + .extracting(om -> ((ExampleBeanOverrideAnnotation) om.overrideAnnotation()).value()) + .isEqualTo("onField"); + } + + @Test + void allowMultipleProcessorsOnDifferentElements() { + BeanOverrideParser parser = new BeanOverrideParser(); + parser.parse(MultipleFieldsWithOnFieldConf.class); + + assertThat(parser.getOverrideMetadata()) + .hasSize(2) + .map(om -> ((ExampleBeanOverrideAnnotation) om.overrideAnnotation()).value()) + .containsOnly("onField1", "onField2"); + } + + @Test + void rejectsMultipleAnnotationsOnSameElement() { + BeanOverrideParser parser = new BeanOverrideParser(); + assertThatRuntimeException().isThrownBy(() -> parser.parse(MultipleOnFieldConf.class)) + .withMessage("Multiple bean override annotations found on annotated field <" + + String.class.getName() + " " + MultipleOnFieldConf.class.getName() + ".message>"); + } + + @Test + void detectsDuplicateMetadata() { + BeanOverrideParser parser = new BeanOverrideParser(); + assertThatRuntimeException().isThrownBy(() -> parser.parse(DuplicateConf.class)) + .withMessage("Duplicate test overrideMetadata {DUPLICATE_TRIGGER}"); + } + + + @Configuration + static class OnFieldConf { + + @ExampleBeanOverrideAnnotation("onField") + String message; + + static String onField() { + return "OK"; + } + + } + + @Configuration + static class MultipleOnFieldConf { + + @ExampleBeanOverrideAnnotation("foo") + @TestBeanOverrideMetaAnnotation + String message; + + static String foo() { + return "foo"; + } + + } + + @Configuration + static class MultipleFieldsWithOnFieldConf { + @ExampleBeanOverrideAnnotation("onField1") + String message; + + @ExampleBeanOverrideAnnotation("onField2") + String messageOther; + + static String onField1() { + return "OK1"; + } + + static String onField2() { + return "OK2"; + } + } + + @Configuration + static class DuplicateConf { + + @ExampleBeanOverrideAnnotation(DUPLICATE_TRIGGER) + String message1; + + @ExampleBeanOverrideAnnotation(DUPLICATE_TRIGGER) + String message2; + + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/OverrideMetadataTests.java b/spring-test/src/test/java/org/springframework/test/bean/override/OverrideMetadataTests.java new file mode 100644 index 00000000000..1a011a6a455 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/OverrideMetadataTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; + +class OverrideMetadataTests { + + static class ConcreteOverrideMetadata extends OverrideMetadata { + + ConcreteOverrideMetadata(Field field, Annotation overrideAnnotation, ResolvableType typeToOverride, + BeanOverrideStrategy strategy) { + super(field, overrideAnnotation, typeToOverride, strategy); + } + + @Override + public String getBeanOverrideDescription() { + return ConcreteOverrideMetadata.class.getSimpleName(); + } + + @Override + protected Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, @Nullable Object existingBeanInstance) { + return BeanOverrideStrategy.REPLACE_DEFINITION; + } + } + + @NonNull + public String annotated = "exampleField"; + + static OverrideMetadata exampleOverride() throws NoSuchFieldException { + final Field annotated = OverrideMetadataTests.class.getField("annotated"); + return new ConcreteOverrideMetadata(Objects.requireNonNull(annotated), annotated.getAnnotation(NonNull.class), + ResolvableType.forClass(String.class), BeanOverrideStrategy.REPLACE_DEFINITION); + } + + @Test + void implicitConfigurations() throws NoSuchFieldException { + final OverrideMetadata metadata = exampleOverride(); + assertThat(metadata.getExpectedBeanName()).as("expectedBeanName") + .isEqualTo(metadata.field().getName()); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessorTests.java b/spring-test/src/test/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessorTests.java new file mode 100644 index 00000000000..b2b63706a82 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/convention/TestBeanOverrideProcessorTests.java @@ -0,0 +1,130 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.convention; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Objects; + +import org.junit.jupiter.api.Test; + +import org.springframework.context.annotation.Bean; +import org.springframework.core.ResolvableType; +import org.springframework.core.annotation.AnnotationUtils; +import org.springframework.test.bean.override.example.ExampleService; +import org.springframework.test.bean.override.example.FailingExampleService; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; + +class TestBeanOverrideProcessorTests { + + @Test + void ensureMethodFindsFromList() { + Method m = TestBeanOverrideProcessor.ensureMethod(MethodConventionConf.class, ExampleService.class, + "example1", "example2", "example3"); + + assertThat(m.getName()).isEqualTo("example2"); + } + + @Test + void ensureMethodNotFound() { + assertThatException().isThrownBy(() -> TestBeanOverrideProcessor.ensureMethod( + MethodConventionConf.class, ExampleService.class, "example1", "example3")) + .withMessage("Found 0 static methods instead of exactly one, matching a name in [example1, example3] with return type " + + ExampleService.class.getName() + " on class " + MethodConventionConf.class.getName()) + .isInstanceOf(IllegalStateException.class); + } + + @Test + void ensureMethodTwoFound() { + assertThatException().isThrownBy(() -> TestBeanOverrideProcessor.ensureMethod( + MethodConventionConf.class, ExampleService.class, "example2", "example4")) + .withMessage("Found 2 static methods instead of exactly one, matching a name in [example2, example4] with return type " + + ExampleService.class.getName() + " on class " + MethodConventionConf.class.getName()) + .isInstanceOf(IllegalStateException.class); + } + + @Test + void ensureMethodNoNameProvided() { + assertThatException().isThrownBy(() -> TestBeanOverrideProcessor.ensureMethod( + MethodConventionConf.class, ExampleService.class)) + .withMessage("At least one expectedMethodName is required") + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void createMetaDataForUnknownExplicitMethod() throws NoSuchFieldException { + Field f = ExplicitMethodNameConf.class.getField("a"); + final TestBean overrideAnnotation = Objects.requireNonNull(AnnotationUtils.getAnnotation(f, TestBean.class)); + TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor(); + assertThatException().isThrownBy(() -> processor.createMetadata(f, overrideAnnotation, ResolvableType.forClass(ExampleService.class))) + .withMessage("Found 0 static methods instead of exactly one, matching a name in [explicit1] with return type " + + ExampleService.class.getName() + " on class " + ExplicitMethodNameConf.class.getName()) + .isInstanceOf(IllegalStateException.class); + } + + @Test + void createMetaDataForKnownExplicitMethod() throws NoSuchFieldException { + Field f = ExplicitMethodNameConf.class.getField("b"); + final TestBean overrideAnnotation = Objects.requireNonNull(AnnotationUtils.getAnnotation(f, TestBean.class)); + TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor(); + assertThat(processor.createMetadata(f, overrideAnnotation, ResolvableType.forClass(ExampleService.class))) + .isInstanceOf(TestBeanOverrideProcessor.MethodConventionOverrideMetadata.class); + } + + @Test + void createMetaDataWithDeferredEnsureMethodCheck() throws NoSuchFieldException { + Field f = MethodConventionConf.class.getField("field"); + final TestBean overrideAnnotation = Objects.requireNonNull(AnnotationUtils.getAnnotation(f, TestBean.class)); + TestBeanOverrideProcessor processor = new TestBeanOverrideProcessor(); + assertThat(processor.createMetadata(f, overrideAnnotation, ResolvableType.forClass(ExampleService.class))) + .isInstanceOf(TestBeanOverrideProcessor.MethodConventionOverrideMetadata.class); + } + + static class MethodConventionConf { + + @TestBean + public ExampleService field; + + @Bean + ExampleService example1() { + return new FailingExampleService(); + } + + static ExampleService example2() { + return new FailingExampleService(); + } + + public static ExampleService example4() { + return new FailingExampleService(); + } + } + + static class ExplicitMethodNameConf { + + @TestBean(methodName = "explicit1") + public ExampleService a; + + @TestBean(methodName = "explicit2") + public ExampleService b; + + static ExampleService explicit2() { + return new FailingExampleService(); + } + } +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideAnnotation.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideAnnotation.java new file mode 100644 index 00000000000..d6474ccdcca --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideAnnotation.java @@ -0,0 +1,38 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.test.bean.override.BeanOverride; + +@BeanOverride(ExampleBeanOverrideProcessor.class) +@Target({ElementType.FIELD, ElementType.ANNOTATION_TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface ExampleBeanOverrideAnnotation { + + static final String DEFAULT_VALUE = "TEST OVERRIDE"; + + String value() default DEFAULT_VALUE; + + boolean createIfMissing() default false; + + String beanName() default ""; +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideProcessor.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideProcessor.java new file mode 100644 index 00000000000..6df216f4fe3 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleBeanOverrideProcessor.java @@ -0,0 +1,49 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Field; + +import org.springframework.core.ResolvableType; +import org.springframework.test.bean.override.BeanOverrideProcessor; +import org.springframework.test.bean.override.OverrideMetadata; + +public class ExampleBeanOverrideProcessor implements BeanOverrideProcessor { + + public ExampleBeanOverrideProcessor() { + } + + private static final TestOverrideMetadata CONSTANT = new TestOverrideMetadata() { + @Override + public String toString() { + return "{DUPLICATE_TRIGGER}"; + } + }; + public static final String DUPLICATE_TRIGGER = "CONSTANT"; + + @Override + public OverrideMetadata createMetadata(Field field, Annotation overrideAnnotation, ResolvableType typeToOverride) { + if (!(overrideAnnotation instanceof ExampleBeanOverrideAnnotation annotation)) { + throw new IllegalStateException("unexpected annotation"); + } + if (annotation.value().equals(DUPLICATE_TRIGGER)) { + return CONSTANT; + } + return new TestOverrideMetadata(field, annotation, typeToOverride); + } +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleService.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleService.java new file mode 100644 index 00000000000..272d42956c5 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/ExampleService.java @@ -0,0 +1,28 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +/** + * Example service interface for mocking tests. + * + * @author Phillip Webb + */ +public interface ExampleService { + + String greeting(); + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/FailingExampleService.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/FailingExampleService.java new file mode 100644 index 00000000000..786b29de65b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/FailingExampleService.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +import org.springframework.stereotype.Service; + +/** + * An {@link ExampleService} that always throws an exception. + * + * @author Phillip Webb + */ +@Service +public class FailingExampleService implements ExampleService { + + @Override + public String greeting() { + throw new IllegalStateException("Failed"); + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/RealExampleService.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/RealExampleService.java new file mode 100644 index 00000000000..df0f1f070c2 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/RealExampleService.java @@ -0,0 +1,37 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +/** + * Example service implementation for spy tests. + * + * @author Phillip Webb + */ +public class RealExampleService implements ExampleService { + + private final String greeting; + + public RealExampleService(String greeting) { + this.greeting = greeting; + } + + @Override + public String greeting() { + return this.greeting; + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/TestBeanOverrideMetaAnnotation.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/TestBeanOverrideMetaAnnotation.java new file mode 100644 index 00000000000..4a6af18901a --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/TestBeanOverrideMetaAnnotation.java @@ -0,0 +1,27 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.FIELD, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@ExampleBeanOverrideAnnotation("foo") +public @interface TestBeanOverrideMetaAnnotation { } diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/TestOverrideMetadata.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/TestOverrideMetadata.java new file mode 100644 index 00000000000..4af81e4293b --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/TestOverrideMetadata.java @@ -0,0 +1,119 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.bean.override.example; + +import java.lang.reflect.AnnotatedElement; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; + +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.core.ResolvableType; +import org.springframework.lang.Nullable; +import org.springframework.test.bean.override.BeanOverrideStrategy; +import org.springframework.test.bean.override.OverrideMetadata; +import org.springframework.util.StringUtils; + +import static org.springframework.test.bean.override.example.ExampleBeanOverrideAnnotation.DEFAULT_VALUE; + +public class TestOverrideMetadata extends OverrideMetadata { + + @Nullable + private final Method method; + + @Nullable + private final String beanName; + + @Nullable + private static Method findMethod(AnnotatedElement element, String methodName) { + if (DEFAULT_VALUE.equals(methodName)) { + return null; + } + if (element instanceof Field f) { + for (Method m : f.getDeclaringClass().getDeclaredMethods()) { + if (!Modifier.isStatic(m.getModifiers())) { + continue; + } + if (m.getName().equals(methodName)) { + return m; + } + } + throw new IllegalStateException("Expected a static method named <" + methodName + "> alongside annotated field <" + f.getName() + ">"); + } + if (element instanceof Method m) { + if (m.getName().equals(methodName) && Modifier.isStatic(m.getModifiers())) { + return m; + } + throw new IllegalStateException("Expected the annotated method to be static and named <" + methodName + ">"); + } + if (element instanceof Class c) { + for (Method m : c.getDeclaredMethods()) { + if (!Modifier.isStatic(m.getModifiers())) { + continue; + } + if (m.getName().equals(methodName)) { + return m; + } + } + throw new IllegalStateException("Expected a static method named <" + methodName + "> on annotated class <" + c.getSimpleName() + ">"); + } + throw new IllegalStateException("Expected the annotated element to be a Field, Method or Class"); + } + + public TestOverrideMetadata(Field field, ExampleBeanOverrideAnnotation overrideAnnotation, ResolvableType typeToOverride) { + super(field, overrideAnnotation, typeToOverride, overrideAnnotation.createIfMissing() ? + BeanOverrideStrategy.REPLACE_OR_CREATE_DEFINITION: BeanOverrideStrategy.REPLACE_DEFINITION); + this.method = findMethod(field, overrideAnnotation.value()); + this.beanName = overrideAnnotation.beanName(); + } + + //Used to trigger duplicate detection in parser test + TestOverrideMetadata() { + super(null, null, null, null); + this.method = null; + this.beanName = null; + } + + @Override + protected String getExpectedBeanName() { + if (StringUtils.hasText(this.beanName)) { + return this.beanName; + } + return super.getExpectedBeanName(); + } + + @Override + public String getBeanOverrideDescription() { + return "test"; + } + + @Override + protected Object createOverride(String beanName, @Nullable BeanDefinition existingBeanDefinition, @Nullable Object existingBeanInstance) { + if (this.method == null) { + return DEFAULT_VALUE; + } + try { + this.method.setAccessible(true); + return this.method.invoke(null); + } + catch (IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/spring-test/src/test/java/org/springframework/test/bean/override/example/package-info.java b/spring-test/src/test/java/org/springframework/test/bean/override/example/package-info.java new file mode 100644 index 00000000000..699aba48693 --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/bean/override/example/package-info.java @@ -0,0 +1,9 @@ +/** + * Example components for testing spring-test Bean overriding feature. + */ +@NonNullApi +@NonNullFields +package org.springframework.test.bean.override.example; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-test/src/test/java/org/springframework/test/context/TestExecutionListenersTests.java b/spring-test/src/test/java/org/springframework/test/context/TestExecutionListenersTests.java index 595208e0596..d69b97501e0 100644 --- a/spring-test/src/test/java/org/springframework/test/context/TestExecutionListenersTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/TestExecutionListenersTests.java @@ -25,6 +25,9 @@ import org.junit.jupiter.api.Test; import org.springframework.core.Ordered; import org.springframework.core.annotation.AliasFor; import org.springframework.core.annotation.AnnotationConfigurationException; +import org.springframework.test.bean.override.BeanOverrideTestExecutionListener; +import org.springframework.test.bean.override.mockito.MockitoResetTestExecutionListener; +import org.springframework.test.bean.override.mockito.MockitoTestExecutionListener; import org.springframework.test.context.event.ApplicationEventsTestExecutionListener; import org.springframework.test.context.event.EventPublishingTestExecutionListener; import org.springframework.test.context.jdbc.SqlScriptsTestExecutionListener; @@ -65,12 +68,15 @@ class TestExecutionListenersTests { List> expected = asList(ServletTestExecutionListener.class,// DirtiesContextBeforeModesTestExecutionListener.class,// ApplicationEventsTestExecutionListener.class,// + MockitoTestExecutionListener.class,// DependencyInjectionTestExecutionListener.class,// micrometerListenerClass,// DirtiesContextTestExecutionListener.class,// TransactionalTestExecutionListener.class,// SqlScriptsTestExecutionListener.class,// - EventPublishingTestExecutionListener.class + EventPublishingTestExecutionListener.class,// + MockitoResetTestExecutionListener.class,// + BeanOverrideTestExecutionListener.class ); assertRegisteredListeners(DefaultListenersTestCase.class, expected); } @@ -84,12 +90,15 @@ class TestExecutionListenersTests { ServletTestExecutionListener.class,// DirtiesContextBeforeModesTestExecutionListener.class,// ApplicationEventsTestExecutionListener.class,// + MockitoTestExecutionListener.class,// DependencyInjectionTestExecutionListener.class,// micrometerListenerClass,// DirtiesContextTestExecutionListener.class,// TransactionalTestExecutionListener.class,// SqlScriptsTestExecutionListener.class,// - EventPublishingTestExecutionListener.class + EventPublishingTestExecutionListener.class,// + MockitoResetTestExecutionListener.class,// + BeanOverrideTestExecutionListener.class ); assertRegisteredListeners(MergedDefaultListenersWithCustomListenerPrependedTestCase.class, expected); } @@ -102,12 +111,15 @@ class TestExecutionListenersTests { List> expected = asList(ServletTestExecutionListener.class,// DirtiesContextBeforeModesTestExecutionListener.class,// ApplicationEventsTestExecutionListener.class,// + MockitoTestExecutionListener.class,// DependencyInjectionTestExecutionListener.class,// micrometerListenerClass,// DirtiesContextTestExecutionListener.class,// TransactionalTestExecutionListener.class, SqlScriptsTestExecutionListener.class,// EventPublishingTestExecutionListener.class,// + MockitoResetTestExecutionListener.class,// + BeanOverrideTestExecutionListener.class,// BazTestExecutionListener.class ); assertRegisteredListeners(MergedDefaultListenersWithCustomListenerAppendedTestCase.class, expected); @@ -121,13 +133,16 @@ class TestExecutionListenersTests { List> expected = asList(ServletTestExecutionListener.class,// DirtiesContextBeforeModesTestExecutionListener.class,// ApplicationEventsTestExecutionListener.class,// + MockitoTestExecutionListener.class,// DependencyInjectionTestExecutionListener.class,// BarTestExecutionListener.class,// micrometerListenerClass,// DirtiesContextTestExecutionListener.class,// TransactionalTestExecutionListener.class,// SqlScriptsTestExecutionListener.class,// - EventPublishingTestExecutionListener.class + EventPublishingTestExecutionListener.class,// + MockitoResetTestExecutionListener.class,// + BeanOverrideTestExecutionListener.class ); assertRegisteredListeners(MergedDefaultListenersWithCustomListenerInsertedTestCase.class, expected); }