diff --git a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostReactiveMethodSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostReactiveMethodSecurityConfigurationTests.java index fed48d08f4..e1eea9d52c 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostReactiveMethodSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostReactiveMethodSecurityConfigurationTests.java @@ -18,6 +18,10 @@ package org.springframework.security.config.annotation.method.configuration; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -34,12 +38,18 @@ import org.springframework.context.annotation.Role; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.access.PermissionEvaluator; import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler; +import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler; +import org.springframework.security.access.hierarchicalroles.RoleHierarchy; +import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl; import org.springframework.security.access.prepost.PostAuthorize; import org.springframework.security.access.prepost.PostFilter; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.access.prepost.PreFilter; import org.springframework.security.authorization.AuthorizationDeniedException; +import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory; +import org.springframework.security.authorization.method.AuthorizeReturnObject; import org.springframework.security.authorization.method.PrePostTemplateDefaults; +import org.springframework.security.config.Customizer; import org.springframework.security.config.test.SpringTestContext; import org.springframework.security.config.test.SpringTestContextExtension; import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults; @@ -49,6 +59,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; @@ -320,6 +331,84 @@ public class PrePostReactiveMethodSecurityConfigurationTests { .containsExactly("dave"); } + @Test + @WithMockUser(authorities = "airplane:read") + public void findByIdWhenAuthorizedResultThenAuthorizes() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + Flight flight = flights.findById("1").block(); + assertThatNoException().isThrownBy(flight::getAltitude); + assertThatNoException().isThrownBy(flight::getSeats); + } + + @Test + @WithMockUser(authorities = "seating:read") + public void findByIdWhenUnauthorizedResultThenDenies() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + Flight flight = flights.findById("1").block(); + assertThatNoException().isThrownBy(flight::getSeats); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block()); + } + + @Test + @WithMockUser(authorities = "seating:read") + public void findAllWhenUnauthorizedResultThenDenies() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findAll().collectList().block().forEach((flight) -> { + assertThatNoException().isThrownBy(flight::getSeats); + assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block()); + }); + } + + @Test + public void removeWhenAuthorizedResultThenRemoves() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.remove("1"); + } + + @Test + @WithMockUser(authorities = "airplane:read") + public void findAllWhenPostFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findAll() + .collectList() + .block() + .forEach((flight) -> assertThat(flight.getPassengers().collectList().block()) + .extracting((p) -> p.getName().block()) + .doesNotContain("Kevin Mitnick")); + } + + @Test + @WithMockUser(authorities = "airplane:read") + public void findAllWhenPreFilterThenFilters() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findAll().collectList().block().forEach((flight) -> { + flight.board(Flux.just("John")).block(); + assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block()) + .doesNotContain("John"); + flight.board(Flux.just("John Doe")).block(); + assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block()) + .contains("John Doe"); + }); + } + + @Test + @WithMockUser(authorities = "seating:read") + public void findAllWhenNestedPreAuthorizeThenAuthorizes() { + this.spring.register(AuthorizeResultConfig.class).autowire(); + FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class); + flights.findAll().collectList().block().forEach((flight) -> { + List passengers = flight.getPassengers().collectList().block(); + passengers.forEach((passenger) -> assertThatExceptionOfType(AccessDeniedException.class) + .isThrownBy(() -> passenger.getName().block())); + }); + } + @Configuration @EnableReactiveMethodSecurity static class MethodSecurityServiceEnabledConfig { @@ -484,4 +573,137 @@ public class PrePostReactiveMethodSecurityConfigurationTests { } + @EnableReactiveMethodSecurity + @Configuration + public static class AuthorizeResultConfig { + + @Bean + @Role(BeanDefinition.ROLE_INFRASTRUCTURE) + static Customizer skipValueTypes() { + return (f) -> f.setTargetVisitor(AuthorizationAdvisorProxyFactory.TargetVisitor.defaultsSkipValueTypes()); + } + + @Bean + FlightRepository flights() { + FlightRepository flights = new FlightRepository(); + Flight one = new Flight("1", 35000d, 35); + one.board(Flux.just("Marie Curie", "Kevin Mitnick", "Ada Lovelace")).block(); + flights.save(one).block(); + Flight two = new Flight("2", 32000d, 72); + two.board(Flux.just("Albert Einstein")).block(); + flights.save(two).block(); + return flights; + } + + @Bean + static MethodSecurityExpressionHandler expressionHandler() { + RoleHierarchy hierarchy = RoleHierarchyImpl.withRolePrefix("") + .role("airplane:read") + .implies("seating:read") + .build(); + DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler(); + expressionHandler.setRoleHierarchy(hierarchy); + return expressionHandler; + } + + @Bean + Authz authz() { + return new Authz(); + } + + public static class Authz { + + public Mono isNotKevinMitnick(Passenger passenger) { + return passenger.getName().map((n) -> !"Kevin Mitnick".equals(n)); + } + + } + + } + + @AuthorizeReturnObject + static class FlightRepository { + + private final Map flights = new ConcurrentHashMap<>(); + + Flux findAll() { + return Flux.fromIterable(this.flights.values()); + } + + Mono findById(String id) { + return Mono.just(this.flights.get(id)); + } + + Mono save(Flight flight) { + this.flights.put(flight.getId(), flight); + return Mono.just(flight); + } + + Mono remove(String id) { + this.flights.remove(id); + return Mono.empty(); + } + + } + + @AuthorizeReturnObject + static class Flight { + + private final String id; + + private final Double altitude; + + private final Integer seats; + + private final List passengers = new ArrayList<>(); + + Flight(String id, Double altitude, Integer seats) { + this.id = id; + this.altitude = altitude; + this.seats = seats; + } + + String getId() { + return this.id; + } + + @PreAuthorize("hasAuthority('airplane:read')") + Mono getAltitude() { + return Mono.just(this.altitude); + } + + @PreAuthorize("hasAuthority('seating:read')") + Mono getSeats() { + return Mono.just(this.seats); + } + + @PostAuthorize("hasAuthority('seating:read')") + @PostFilter("@authz.isNotKevinMitnick(filterObject)") + Flux getPassengers() { + return Flux.fromIterable(this.passengers); + } + + @PreAuthorize("hasAuthority('seating:read')") + @PreFilter("filterObject.contains(' ')") + Mono board(Flux passengers) { + return passengers.doOnNext((passenger) -> this.passengers.add(new Passenger(passenger))).then(Mono.empty()); + } + + } + + public static class Passenger { + + String name; + + public Passenger(String name) { + this.name = name; + } + + @PreAuthorize("hasAuthority('airplane:read')") + public Mono getName() { + return Mono.just(this.name); + } + + } + }