@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication;
@@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication;
import java.io.ByteArrayOutputStream ;
import java.io.IOException ;
import java.io.ObjectOutputStream ;
import java.io.StringReader ;
import java.time.Instant ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
import javax.xml.parsers.DocumentBuilder ;
import javax.xml.parsers.DocumentBuilderFactory ;
import org.hamcrest.BaseMatcher ;
import org.hamcrest.Description ;
@ -33,27 +37,40 @@ import org.joda.time.Duration;
@@ -33,27 +37,40 @@ import org.joda.time.Duration;
import org.junit.Rule ;
import org.junit.Test ;
import org.junit.rules.ExpectedException ;
import org.opensaml.core.xml.XMLObject ;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport ;
import org.opensaml.core.xml.io.Marshaller ;
import org.opensaml.saml.saml2.core.Assertion ;
import org.opensaml.saml.saml2.core.AttributeStatement ;
import org.opensaml.saml.saml2.core.AttributeValue ;
import org.opensaml.saml.saml2.core.EncryptedAssertion ;
import org.opensaml.saml.saml2.core.EncryptedID ;
import org.opensaml.saml.saml2.core.NameID ;
import org.opensaml.saml.saml2.core.Response ;
import org.w3c.dom.Document ;
import org.w3c.dom.Element ;
import org.xml.sax.InputSource ;
import org.springframework.security.core.Authentication ;
import org.springframework.security.saml2.credentials.Saml2X509Credential ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed ;
import static org.junit.Assert.assertEquals ;
import static org.junit.Assert.assertTrue ;
import static org.mockito.ArgumentMatchers.any ;
import static org.mockito.Mockito.atLeastOnce ;
import static org.mockito.Mockito.mock ;
import static org.mockito.Mockito.verify ;
import static org.mockito.Mockito.when ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential ;
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential ;
import static org.springframework.test.util.AssertionErrors.assertEquals ;
import static org.springframework.test.util.AssertionErrors.assertTrue ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response ;
import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed ;
import static org.springframework.util.StringUtils.hasText ;
/ * *
@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests {
@@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests {
public void authenticateWhenAssertionContainsAttributesThenItSucceeds ( ) {
Response response = response ( ) ;
Assertion assertion = assertion ( ) ;
attributeStatements ( ) . forEach ( as - > assertion . getAttributeStatements ( ) . add ( as ) ) ;
List < AttributeStatement > attributes = attributeStatements ( ) ;
assertion . getAttributeStatements ( ) . addAll ( attributes ) ;
signed ( assertion , assertingPartySigningCredential ( ) , RELYING_PARTY_ENTITY_ID ) ;
response . getAssertions ( ) . add ( assertion ) ;
Saml2AuthenticationToken token = token ( response , relyingPartyVerifyingCredential ( ) ) ;
Authentication authentication = this . provider . authenticate ( token ) ;
Saml2AuthenticatedPrincipal principal = ( Saml2AuthenticatedPrincipal ) authentication . getPrincipal ( ) ;
Map < String , Object > attributes = new LinkedHashMap < > ( ) ;
attributes . put ( "email" , Arrays . asList ( "john.doe@example.com" , "doe.john@example.com" ) ) ;
attributes . put ( "name" , Collections . singletonList ( "John Doe" ) ) ;
attributes . put ( "age" , Collections . singletonList ( 21 ) ) ;
attributes . put ( "website" , Collections . singletonList ( "https://johndoe.com/" ) ) ;
attributes . put ( "registered" , Collections . singletonList ( true ) ) ;
Map < String , Object > expected = new LinkedHashMap < > ( ) ;
expected . put ( "email" , Arrays . asList ( "john.doe@example.com" , "doe.john@example.com" ) ) ;
expected . put ( "name" , Collections . singletonList ( "John Doe" ) ) ;
expected . put ( "age" , Collections . singletonList ( 21 ) ) ;
expected . put ( "website" , Collections . singletonList ( "https://johndoe.com/" ) ) ;
expected . put ( "registered" , Collections . singletonList ( true ) ) ;
Instant registeredDate = Instant . ofEpochMilli ( DateTime . parse ( "1970-01-01T00:00:00Z" ) . getMillis ( ) ) ;
attributes . put ( "registeredDate" , Collections . singletonList ( registeredDate ) ) ;
expected . put ( "registeredDate" , Collections . singletonList ( registeredDate ) ) ;
assertEquals ( "Values should be equal" , "John Doe" , principal . getFirstAttribute ( "name" ) ) ;
assertTrue ( "Attributes should be equal" , attributes . equals ( principal . getAttributes ( ) ) ) ;
assertEquals ( "John Doe" , principal . getFirstAttribute ( "name" ) ) ;
assertEquals ( expected , principal . getAttributes ( ) ) ;
}
@Test
public void authenticateWhenAttributeValueMarshallerConfiguredThenUses ( ) throws Exception {
Response response = response ( ) ;
Assertion assertion = assertion ( ) ;
List < AttributeStatement > attributes = attributeStatements ( ) ;
assertion . getAttributeStatements ( ) . addAll ( attributes ) ;
signed ( assertion , assertingPartySigningCredential ( ) , RELYING_PARTY_ENTITY_ID ) ;
response . getAssertions ( ) . add ( assertion ) ;
Saml2AuthenticationToken token = token ( response , relyingPartyVerifyingCredential ( ) ) ;
Element attributeElement = element ( "<element>value</element>" ) ;
Marshaller marshaller = mock ( Marshaller . class ) ;
when ( marshaller . marshall ( any ( XMLObject . class ) ) ) . thenReturn ( attributeElement ) ;
try {
XMLObjectProviderRegistrySupport . getMarshallerFactory ( ) . registerMarshaller ( AttributeValue . DEFAULT_ELEMENT_NAME , marshaller ) ;
this . provider . authenticate ( token ) ;
verify ( marshaller , atLeastOnce ( ) ) . marshall ( any ( XMLObject . class ) ) ;
} finally {
XMLObjectProviderRegistrySupport . getMarshallerFactory ( ) . deregisterMarshaller ( AttributeValue . DEFAULT_ELEMENT_NAME ) ;
}
}
@Test
@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests {
@@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests {
return new Saml2AuthenticationToken ( payload ,
DESTINATION , ASSERTING_PARTY_ENTITY_ID , RELYING_PARTY_ENTITY_ID , Arrays . asList ( credentials ) ) ;
}
private static Element element ( String xml ) throws Exception {
DocumentBuilderFactory factory = DocumentBuilderFactory . newInstance ( ) ;
DocumentBuilder builder = factory . newDocumentBuilder ( ) ;
Document doc = builder . parse ( new InputSource ( new StringReader ( xml ) ) ) ;
return doc . getDocumentElement ( ) ;
}
}