@ -16,27 +16,17 @@
@@ -16,27 +16,17 @@
package org.springframework.security.saml2.provider.service.registration ;
import java.io.IOException ;
import java.io.InputStream ;
import java.security.cert.CertificateFactory ;
import java.security.cert.X509Certificate ;
import java.util.Collection ;
import java.util.Iterator ;
import java.util.List ;
import org.junit.jupiter.api.AfterEach ;
import org.junit.jupiter.api.BeforeEach ;
import org.junit.jupiter.api.Test ;
import org.springframework.core.io.ClassPathResource ;
import org.springframework.core.serializer.DefaultSerializer ;
import org.springframework.core.serializer.Serializer ;
import org.springframework.jdbc.core.JdbcOperations ;
import org.springframework.jdbc.core.JdbcTemplate ;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase ;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder ;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType ;
import org.springframework.security.saml2.core.Saml2X509Credential ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
@@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql" ;
private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
+ JdbcAssertingPartyMetadataRepository . COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ;
private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php" ;
private static final String SINGLE_SIGNON_URL = "https://localhost/SSO" ;
private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding . REDIRECT . getUrn ( ) ;
private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false ;
private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO" ;
private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response" ;
private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding . REDIRECT . getUrn ( ) ;
private static final List < String > SIGNING_ALGORITHMS = List . of ( "http://www.w3.org/2001/04/xmldsig-more#rsa-sha512" ) ;
private X509Certificate certificate ;
private EmbeddedDatabase db ;
private JdbcAssertingPartyMetadataRepository repository ;
private JdbcOperations jdbcOperations ;
private final Serializer < Object > serializer = new DefaultSerializer ( ) ;
private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations . full ( )
. build ( )
. getAssertingPartyMetadata ( ) ;
@BeforeEach
void setUp ( ) {
this . db = createDb ( ) ;
this . jdbcOperations = new JdbcTemplate ( this . db ) ;
this . repository = new JdbcAssertingPartyMetadataRepository ( this . jdbcOperations ) ;
this . certificate = loadCertificate ( "rsa.crt" ) ;
}
@AfterEach
@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests {
@@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests {
}
@Test
void findByEntityId ( ) throws IOException {
this . jdbcOperations . update ( SAVE_SQL , ENTITY_ID , SINGLE_SIGNON_URL , SINGLE_SIGNON_BINDING ,
SINGLE_SIGNON_SIGN_REQUEST , this . serializer . serializeToByteArray ( SIGNING_ALGORITHMS ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) , SINGLE_LOGOUT_URL ,
SINGLE_LOGOUT_RESPONSE_URL , SINGLE_LOGOUT_BINDING ) ;
void findByEntityId ( ) {
this . repository . save ( this . metadata ) ;
AssertingPartyMetadata found = this . repository . findByEntityId ( ENTITY_ID ) ;
AssertingPartyMetadata found = this . repository . findByEntityId ( this . metadata . getEntityId ( ) ) ;
assertThat ( found ) . isNotNull ( ) ;
assertThat ( found . getEntityId ( ) ) . isEqualTo ( ENTITY_ID ) ;
assertThat ( found . getSingleSignOnServiceLocation ( ) ) . isEqualTo ( SINGLE_SIGNON_URL ) ;
assertThat ( found . getSingleSignOnServiceBinding ( ) . getUrn ( ) ) . isEqualTo ( SINGLE_SIGNON_BINDING ) ;
assertThat ( found . getWantAuthnRequestsSigned ( ) ) . isEqualTo ( SINGLE_SIGNON_SIGN_REQUEST ) ;
assertThat ( found . getSingleLogoutServiceLocation ( ) ) . isEqualTo ( SINGLE_LOGOUT_URL ) ;
assertThat ( found . getSingleLogoutServiceResponseLocation ( ) ) . isEqualTo ( SINGLE_LOGOUT_RESPONSE_URL ) ;
assertThat ( found . getSingleLogoutServiceBinding ( ) . getUrn ( ) ) . isEqualTo ( SINGLE_LOGOUT_BINDING ) ;
assertThat ( found . getSigningAlgorithms ( ) ) . contains ( SIGNING_ALGORITHMS . get ( 0 ) ) ;
assertThat ( found . getVerificationX509Credentials ( ) ) . hasSize ( 1 ) ;
assertThat ( found . getEncryptionX509Credentials ( ) ) . hasSize ( 1 ) ;
assertAssertingPartyEquals ( found , this . metadata ) ;
}
@Test
@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests {
@@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests {
}
@Test
void iterator ( ) throws IOException {
this . jdbcOperations . update ( SAVE_SQL , ENTITY_ID , SINGLE_SIGNON_URL , SINGLE_SIGNON_BINDING ,
SINGLE_SIGNON_SIGN_REQUEST , this . serializer . serializeToByteArray ( SIGNING_ALGORITHMS ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) , SINGLE_LOGOUT_URL ,
SINGLE_LOGOUT_RESPONSE_URL , SINGLE_LOGOUT_BINDING ) ;
this . jdbcOperations . update ( SAVE_SQL , "https://localhost/simplesaml2/saml2/idp/metadata.php" , SINGLE_SIGNON_URL ,
SINGLE_SIGNON_BINDING , SINGLE_SIGNON_SIGN_REQUEST ,
this . serializer . serializeToByteArray ( SIGNING_ALGORITHMS ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) ,
this . serializer . serializeToByteArray ( asCredentials ( this . certificate ) ) , SINGLE_LOGOUT_URL ,
SINGLE_LOGOUT_RESPONSE_URL , SINGLE_LOGOUT_BINDING ) ;
void iterator ( ) {
AssertingPartyMetadata second = RelyingPartyRegistration . withAssertingPartyMetadata ( this . metadata )
. assertingPartyMetadata ( ( a ) - > a . entityId ( "https://example.org/idp" ) )
. build ( )
. getAssertingPartyMetadata ( ) ;
this . repository . save ( this . metadata ) ;
this . repository . save ( second ) ;
Iterator < AssertingPartyMetadata > iterator = this . repository . iterator ( ) ;
AssertingPartyMetadata first = iterator . next ( ) ;
assertThat ( first ) . isNotNull ( ) ;
AssertingPartyMetadata second = iterator . next ( ) ;
assertThat ( second ) . isNotNull ( ) ;
assertAssertingPartyEquals ( iterator . next ( ) , this . metadata ) ;
assertAssertingPartyEquals ( iterator . next ( ) , second ) ;
assertThat ( iterator . hasNext ( ) ) . isFalse ( ) ;
}
@Test
void saveWhenExistingThenUpdates ( ) {
this . repository . save ( this . metadata ) ;
boolean existing = this . metadata . getWantAuthnRequestsSigned ( ) ;
this . repository . save ( this . metadata . mutate ( ) . wantAuthnRequestsSigned ( ! existing ) . build ( ) ) ;
boolean updated = this . repository . findByEntityId ( this . metadata . getEntityId ( ) ) . getWantAuthnRequestsSigned ( ) ;
assertThat ( existing ) . isNotEqualTo ( updated ) ;
}
private static EmbeddedDatabase createDb ( ) {
return createDb ( SCHEMA_SQL_RESOURCE ) ;
}
@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests {
@@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests {
// @formatter:on
}
private X509Certificate loadCertificate ( String path ) {
try ( InputStream is = new ClassPathResource ( path ) . getInputStream ( ) ) {
CertificateFactory factory = CertificateFactory . getInstance ( "X.509" ) ;
return ( X509Certificate ) factory . generateCertificate ( is ) ;
}
catch ( Exception ex ) {
throw new RuntimeException ( "Error loading certificate from " + path , ex ) ;
}
}
private Collection < Saml2X509Credential > asCredentials ( X509Certificate certificate ) {
return List . of ( new Saml2X509Credential ( certificate , Saml2X509Credential . Saml2X509CredentialType . ENCRYPTION ,
Saml2X509Credential . Saml2X509CredentialType . VERIFICATION ) ) ;
private void assertAssertingPartyEquals ( AssertingPartyMetadata found , AssertingPartyMetadata expected ) {
assertThat ( found ) . isNotNull ( ) ;
assertThat ( found . getEntityId ( ) ) . isEqualTo ( expected . getEntityId ( ) ) ;
assertThat ( found . getSingleSignOnServiceLocation ( ) ) . isEqualTo ( expected . getSingleSignOnServiceLocation ( ) ) ;
assertThat ( found . getSingleSignOnServiceBinding ( ) ) . isEqualTo ( expected . getSingleSignOnServiceBinding ( ) ) ;
assertThat ( found . getWantAuthnRequestsSigned ( ) ) . isEqualTo ( expected . getWantAuthnRequestsSigned ( ) ) ;
assertThat ( found . getSingleLogoutServiceLocation ( ) ) . isEqualTo ( expected . getSingleLogoutServiceLocation ( ) ) ;
assertThat ( found . getSingleLogoutServiceResponseLocation ( ) )
. isEqualTo ( expected . getSingleLogoutServiceResponseLocation ( ) ) ;
assertThat ( found . getSingleLogoutServiceBinding ( ) ) . isEqualTo ( expected . getSingleLogoutServiceBinding ( ) ) ;
assertThat ( found . getSigningAlgorithms ( ) ) . containsAll ( expected . getSigningAlgorithms ( ) ) ;
assertThat ( found . getVerificationX509Credentials ( ) ) . containsAll ( expected . getVerificationX509Credentials ( ) ) ;
assertThat ( found . getEncryptionX509Credentials ( ) ) . containsAll ( expected . getEncryptionX509Credentials ( ) ) ;
}
}