From 61c9cbc3f5d2fb97751991bd6bfa8530a3acc6dc Mon Sep 17 00:00:00 2001 From: Stephane Nicoll Date: Mon, 12 Jun 2023 16:33:35 +0200 Subject: [PATCH] Retain active profiles used during AOT processing This commit makes sure that profiles that have been explicitly enabled during AOT optimizations are automatically enabled when using those optimizations. If other profiles are set at runtime, they take precedence over the ones defined during AOT processing. Closes gh-30421 --- .../aot/ApplicationContextAotGenerator.java | 6 ++-- ...ionContextInitializationCodeGenerator.java | 23 ++++++++++-- .../ApplicationContextAotGeneratorTests.java | 35 +++++++++++++++++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java index 7e97b686254..0331b145c10 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextAotGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,9 +52,9 @@ public class ApplicationContextAotGenerator { GenerationContext generationContext) { return withCglibClassHandler(new CglibClassHandler(generationContext), () -> { applicationContext.refreshForAotProcessing(generationContext.getRuntimeHints()); - DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); ApplicationContextInitializationCodeGenerator codeGenerator = - new ApplicationContextInitializationCodeGenerator(generationContext); + new ApplicationContextInitializationCodeGenerator(applicationContext, generationContext); + DefaultListableBeanFactory beanFactory = applicationContext.getDefaultListableBeanFactory(); new BeanFactoryInitializationAotContributions(beanFactory).applyTo(generationContext, codeGenerator); return codeGenerator.getGeneratedClass().getName(); }); diff --git a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java index 3a7253b0440..5305508b9da 100644 --- a/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java +++ b/spring-context/src/main/java/org/springframework/context/aot/ApplicationContextInitializationCodeGenerator.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.context.aot; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.function.Function; @@ -49,6 +50,7 @@ import org.springframework.lang.Nullable; * Internal code generator to create the {@link ApplicationContextInitializer}. * * @author Phillip Webb + * @author Stephane Nicoll * @since 6.0 */ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitializationCode { @@ -57,12 +59,15 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia private static final String APPLICATION_CONTEXT_VARIABLE = "applicationContext"; - private final List initializers = new ArrayList<>(); + private final GenericApplicationContext applicationContext; private final GeneratedClass generatedClass; + private final List initializers = new ArrayList<>(); + - ApplicationContextInitializationCodeGenerator(GenerationContext generationContext) { + ApplicationContextInitializationCodeGenerator(GenericApplicationContext applicationContext, GenerationContext generationContext) { + this.applicationContext = applicationContext; this.generatedClass = generationContext.getGeneratedClasses() .addForFeature("ApplicationContextInitializer", this::generateType); this.generatedClass.reserveMethodNames(INITIALIZE_METHOD); @@ -97,6 +102,7 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia BEAN_FACTORY_VARIABLE, ContextAnnotationAutowireCandidateResolver.class); code.addStatement("$L.setDependencyComparator($T.INSTANCE)", BEAN_FACTORY_VARIABLE, AnnotationAwareOrderComparator.class); + code.add(generateActiveProfilesInitializeCode()); ArgumentCodeGenerator argCodeGenerator = createInitializerMethodArgumentCodeGenerator(); for (MethodReference initializer : this.initializers) { code.addStatement(initializer.toInvokeCodeBlock(argCodeGenerator, this.generatedClass.getName())); @@ -104,6 +110,17 @@ class ApplicationContextInitializationCodeGenerator implements BeanFactoryInitia return code.build(); } + private CodeBlock generateActiveProfilesInitializeCode() { + CodeBlock.Builder code = CodeBlock.builder(); + ConfigurableEnvironment environment = this.applicationContext.getEnvironment(); + if (!Arrays.equals(environment.getActiveProfiles(), environment.getDefaultProfiles())) { + for (String activeProfile : environment.getActiveProfiles()) { + code.addStatement("$L.getEnvironment().addActiveProfile($S)", APPLICATION_CONTEXT_VARIABLE, activeProfile); + } + } + return code.build(); + } + static ArgumentCodeGenerator createInitializerMethodArgumentCodeGenerator() { return ArgumentCodeGenerator.from(new InitializerMethodArgumentCodeGenerator()); } diff --git a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java index 01a210cd0d5..a2ab23b8ade 100644 --- a/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java +++ b/spring-context/src/test/java/org/springframework/context/aot/ApplicationContextAotGeneratorTests.java @@ -22,9 +22,13 @@ import java.lang.reflect.Proxy; import java.util.List; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.stream.Stream; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.springframework.aot.generate.GeneratedFiles.Kind; import org.springframework.aot.generate.GenerationContext; @@ -384,6 +388,37 @@ class ApplicationContextAotGeneratorTests { } + @Nested + static class ActiveProfile { + + @ParameterizedTest + @MethodSource("activeProfilesParameters") + void processAheadOfTimeWhenHasActiveProfiles(String[] aotProfiles, String[] runtimeProfiles, String[] expectedActiveProfiles) { + GenericApplicationContext applicationContext = new GenericApplicationContext(); + if (aotProfiles.length != 0) { + applicationContext.getEnvironment().setActiveProfiles(aotProfiles); + } + testCompiledResult(applicationContext, (initializer, compiled) -> { + GenericApplicationContext freshApplicationContext = new GenericApplicationContext(); + if (runtimeProfiles.length != 0) { + freshApplicationContext.getEnvironment().setActiveProfiles(runtimeProfiles); + } + initializer.initialize(freshApplicationContext); + freshApplicationContext.refresh(); + assertThat(freshApplicationContext.getEnvironment().getActiveProfiles()).containsExactly(expectedActiveProfiles); + }); + } + + static Stream activeProfilesParameters() { + return Stream.of(Arguments.of(new String[] { "aot", "prod" }, new String[] {}, new String[] { "aot", "prod" }), + Arguments.of(new String[] {}, new String[] { "aot", "prod" }, new String[] { "aot", "prod" }), + Arguments.of(new String[] { "aot" }, new String[] { "prod" }, new String[] { "prod", "aot" }), + Arguments.of(new String[] { "aot", "prod" }, new String[] { "aot", "prod" }, new String[] { "aot", "prod" }), + Arguments.of(new String[] { "default" }, new String[] {}, new String[] {})); + } + + } + private Consumer> doesNotHaveProxyFor(Class target) { return hints -> assertThat(hints).noneMatch(hint -> hint.getProxiedInterfaces().get(0).equals(TypeReference.of(target)));