diff --git a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java index 755217cd699..32a25f9f6e6 100644 --- a/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java +++ b/spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java @@ -128,7 +128,11 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor, if (beanName != null) { // We are overriding an existing bean by-type. beanName = BeanFactoryUtils.transformedBeanName(beanName); - existingBeanDefinition = beanFactory.getBeanDefinition(beanName); + // If we are overriding a manually registered singleton, we won't find + // an existing bean definition. + if (beanFactory.containsBeanDefinition(beanName)) { + existingBeanDefinition = beanFactory.getBeanDefinition(beanName); + } } else { // We will later generate a name for the nonexistent bean, but since NullAway @@ -150,6 +154,12 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor, } } + // Ensure we don't have any manually registered singletons registered, since we + // register a bean override instance as a manual singleton at the end of this method. + if (beanFactory.containsSingleton(beanName)) { + destroySingleton(beanFactory, beanName); + } + if (existingBeanDefinition != null) { // Validate the existing bean definition. // @@ -333,6 +343,10 @@ class BeanOverrideBeanFactoryPostProcessor implements BeanFactoryPostProcessor, // Since the isSingleton() check above may have registered a singleton as a side // effect -- for example, for a FactoryBean -- we need to destroy the singleton, // because we later manually register a bean override instance as a singleton. + destroySingleton(beanFactory, beanName); + } + + private static void destroySingleton(ConfigurableListableBeanFactory beanFactory, String beanName) { if (beanFactory instanceof DefaultListableBeanFactory dlbf) { dlbf.destroySingleton(beanName); } diff --git a/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanManuallyRegisteredSingletonTests.java b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanManuallyRegisteredSingletonTests.java new file mode 100644 index 00000000000..4cb2d8cfecd --- /dev/null +++ b/spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/MockitoBeanManuallyRegisteredSingletonTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.test.context.bean.override.mockito; + +import org.junit.jupiter.api.Test; + +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.test.context.bean.override.mockito.MockitoBeanManuallyRegisteredSingletonTests.SingletonRegistrar; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.when; + +/** + * Verifies support for overriding a manually registered singleton bean with + * {@link MockitoBean @MockitoBean}. + * + * @author Andy Wilkinson + * @author Sam Brannen + * @since 6.2 + */ +@SpringJUnitConfig(initializers = SingletonRegistrar.class) +class MockitoBeanManuallyRegisteredSingletonTests { + + @MockitoBean + MessageService messageService; + + @Test + void test() { + when(messageService.getMessage()).thenReturn("override"); + assertThat(messageService.getMessage()).isEqualTo("override"); + } + + static class SingletonRegistrar implements ApplicationContextInitializer { + + @Override + public void initialize(ConfigurableApplicationContext applicationContext) { + applicationContext.getBeanFactory().registerSingleton("messageService", new MessageService()); + } + } + + static class MessageService { + + String getMessage() { + return "production"; + } + } + +}