@ -18,6 +18,10 @@ package org.springframework.security.config.annotation.method.configuration;
@@ -18,6 +18,10 @@ package org.springframework.security.config.annotation.method.configuration;
import java.lang.annotation.Retention ;
import java.lang.annotation.RetentionPolicy ;
import java.util.ArrayList ;
import java.util.List ;
import java.util.Map ;
import java.util.concurrent.ConcurrentHashMap ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.api.extension.ExtendWith ;
@ -34,12 +38,18 @@ import org.springframework.context.annotation.Role;
@@ -34,12 +38,18 @@ import org.springframework.context.annotation.Role;
import org.springframework.security.access.AccessDeniedException ;
import org.springframework.security.access.PermissionEvaluator ;
import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler ;
import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler ;
import org.springframework.security.access.hierarchicalroles.RoleHierarchy ;
import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl ;
import org.springframework.security.access.prepost.PostAuthorize ;
import org.springframework.security.access.prepost.PostFilter ;
import org.springframework.security.access.prepost.PreAuthorize ;
import org.springframework.security.access.prepost.PreFilter ;
import org.springframework.security.authorization.AuthorizationDeniedException ;
import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory ;
import org.springframework.security.authorization.method.AuthorizeReturnObject ;
import org.springframework.security.authorization.method.PrePostTemplateDefaults ;
import org.springframework.security.config.Customizer ;
import org.springframework.security.config.test.SpringTestContext ;
import org.springframework.security.config.test.SpringTestContextExtension ;
import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults ;
@ -49,6 +59,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -49,6 +59,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType ;
import static org.assertj.core.api.Assertions.assertThatNoException ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.ArgumentMatchers.eq ;
import static org.mockito.BDDMockito.given ;
@ -320,6 +331,84 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
@@ -320,6 +331,84 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
. containsExactly ( "dave" ) ;
}
@Test
@WithMockUser ( authorities = "airplane:read" )
public void findByIdWhenAuthorizedResultThenAuthorizes ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
Flight flight = flights . findById ( "1" ) . block ( ) ;
assertThatNoException ( ) . isThrownBy ( flight : : getAltitude ) ;
assertThatNoException ( ) . isThrownBy ( flight : : getSeats ) ;
}
@Test
@WithMockUser ( authorities = "seating:read" )
public void findByIdWhenUnauthorizedResultThenDenies ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
Flight flight = flights . findById ( "1" ) . block ( ) ;
assertThatNoException ( ) . isThrownBy ( flight : : getSeats ) ;
assertThatExceptionOfType ( AccessDeniedException . class ) . isThrownBy ( ( ) - > flight . getAltitude ( ) . block ( ) ) ;
}
@Test
@WithMockUser ( authorities = "seating:read" )
public void findAllWhenUnauthorizedResultThenDenies ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
flights . findAll ( ) . collectList ( ) . block ( ) . forEach ( ( flight ) - > {
assertThatNoException ( ) . isThrownBy ( flight : : getSeats ) ;
assertThatExceptionOfType ( AccessDeniedException . class ) . isThrownBy ( ( ) - > flight . getAltitude ( ) . block ( ) ) ;
} ) ;
}
@Test
public void removeWhenAuthorizedResultThenRemoves ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
flights . remove ( "1" ) ;
}
@Test
@WithMockUser ( authorities = "airplane:read" )
public void findAllWhenPostFilterThenFilters ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
flights . findAll ( )
. collectList ( )
. block ( )
. forEach ( ( flight ) - > assertThat ( flight . getPassengers ( ) . collectList ( ) . block ( ) )
. extracting ( ( p ) - > p . getName ( ) . block ( ) )
. doesNotContain ( "Kevin Mitnick" ) ) ;
}
@Test
@WithMockUser ( authorities = "airplane:read" )
public void findAllWhenPreFilterThenFilters ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
flights . findAll ( ) . collectList ( ) . block ( ) . forEach ( ( flight ) - > {
flight . board ( Flux . just ( "John" ) ) . block ( ) ;
assertThat ( flight . getPassengers ( ) . collectList ( ) . block ( ) ) . extracting ( ( p ) - > p . getName ( ) . block ( ) )
. doesNotContain ( "John" ) ;
flight . board ( Flux . just ( "John Doe" ) ) . block ( ) ;
assertThat ( flight . getPassengers ( ) . collectList ( ) . block ( ) ) . extracting ( ( p ) - > p . getName ( ) . block ( ) )
. contains ( "John Doe" ) ;
} ) ;
}
@Test
@WithMockUser ( authorities = "seating:read" )
public void findAllWhenNestedPreAuthorizeThenAuthorizes ( ) {
this . spring . register ( AuthorizeResultConfig . class ) . autowire ( ) ;
FlightRepository flights = this . spring . getContext ( ) . getBean ( FlightRepository . class ) ;
flights . findAll ( ) . collectList ( ) . block ( ) . forEach ( ( flight ) - > {
List < Passenger > passengers = flight . getPassengers ( ) . collectList ( ) . block ( ) ;
passengers . forEach ( ( passenger ) - > assertThatExceptionOfType ( AccessDeniedException . class )
. isThrownBy ( ( ) - > passenger . getName ( ) . block ( ) ) ) ;
} ) ;
}
@Configuration
@EnableReactiveMethodSecurity
static class MethodSecurityServiceEnabledConfig {
@ -484,4 +573,137 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
@@ -484,4 +573,137 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
}
@EnableReactiveMethodSecurity
@Configuration
public static class AuthorizeResultConfig {
@Bean
@Role ( BeanDefinition . ROLE_INFRASTRUCTURE )
static Customizer < AuthorizationAdvisorProxyFactory > skipValueTypes ( ) {
return ( f ) - > f . setTargetVisitor ( AuthorizationAdvisorProxyFactory . TargetVisitor . defaultsSkipValueTypes ( ) ) ;
}
@Bean
FlightRepository flights ( ) {
FlightRepository flights = new FlightRepository ( ) ;
Flight one = new Flight ( "1" , 35000d , 35 ) ;
one . board ( Flux . just ( "Marie Curie" , "Kevin Mitnick" , "Ada Lovelace" ) ) . block ( ) ;
flights . save ( one ) . block ( ) ;
Flight two = new Flight ( "2" , 32000d , 72 ) ;
two . board ( Flux . just ( "Albert Einstein" ) ) . block ( ) ;
flights . save ( two ) . block ( ) ;
return flights ;
}
@Bean
static MethodSecurityExpressionHandler expressionHandler ( ) {
RoleHierarchy hierarchy = RoleHierarchyImpl . withRolePrefix ( "" )
. role ( "airplane:read" )
. implies ( "seating:read" )
. build ( ) ;
DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler ( ) ;
expressionHandler . setRoleHierarchy ( hierarchy ) ;
return expressionHandler ;
}
@Bean
Authz authz ( ) {
return new Authz ( ) ;
}
public static class Authz {
public Mono < Boolean > isNotKevinMitnick ( Passenger passenger ) {
return passenger . getName ( ) . map ( ( n ) - > ! "Kevin Mitnick" . equals ( n ) ) ;
}
}
}
@AuthorizeReturnObject
static class FlightRepository {
private final Map < String , Flight > flights = new ConcurrentHashMap < > ( ) ;
Flux < Flight > findAll ( ) {
return Flux . fromIterable ( this . flights . values ( ) ) ;
}
Mono < Flight > findById ( String id ) {
return Mono . just ( this . flights . get ( id ) ) ;
}
Mono < Flight > save ( Flight flight ) {
this . flights . put ( flight . getId ( ) , flight ) ;
return Mono . just ( flight ) ;
}
Mono < Void > remove ( String id ) {
this . flights . remove ( id ) ;
return Mono . empty ( ) ;
}
}
@AuthorizeReturnObject
static class Flight {
private final String id ;
private final Double altitude ;
private final Integer seats ;
private final List < Passenger > passengers = new ArrayList < > ( ) ;
Flight ( String id , Double altitude , Integer seats ) {
this . id = id ;
this . altitude = altitude ;
this . seats = seats ;
}
String getId ( ) {
return this . id ;
}
@PreAuthorize ( "hasAuthority('airplane:read')" )
Mono < Double > getAltitude ( ) {
return Mono . just ( this . altitude ) ;
}
@PreAuthorize ( "hasAuthority('seating:read')" )
Mono < Integer > getSeats ( ) {
return Mono . just ( this . seats ) ;
}
@PostAuthorize ( "hasAuthority('seating:read')" )
@PostFilter ( "@authz.isNotKevinMitnick(filterObject)" )
Flux < Passenger > getPassengers ( ) {
return Flux . fromIterable ( this . passengers ) ;
}
@PreAuthorize ( "hasAuthority('seating:read')" )
@PreFilter ( "filterObject.contains(' ')" )
Mono < Void > board ( Flux < String > passengers ) {
return passengers . doOnNext ( ( passenger ) - > this . passengers . add ( new Passenger ( passenger ) ) ) . then ( Mono . empty ( ) ) ;
}
}
public static class Passenger {
String name ;
public Passenger ( String name ) {
this . name = name ;
}
@PreAuthorize ( "hasAuthority('airplane:read')" )
public Mono < String > getName ( ) {
return Mono . just ( this . name ) ;
}
}
}