From 561c7867264e6370a74c52fa96babddfcfc0827a Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Fri, 23 Aug 2024 15:53:50 -0600 Subject: [PATCH] Repair Flaky Tests The issue turned out to be that OpenSAML first sends two HEAD requests before sending a GET to retrieve the metadata. The way the MockWebServer dispatcher was configured, it would send back the metadata on each request. This created a situation where sockets were being closed by the client before the server had sent all the response, resulting in a broken pipe. The tests would succeed most of the time due to lucky timing between the client closing the socket and the server having sent all of its (unrequested) data. This version sends an expected HEAD response when requested. Issue gh-15395 --- ...AssertingPartyMetadataRepositoryTests.java | 161 +++++++++++------- ...AssertingPartyMetadataRepositoryTests.java | 161 +++++++++++------- 2 files changed, 194 insertions(+), 128 deletions(-) diff --git a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java index b4663507b6..2da900fd7c 100644 --- a/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java +++ b/saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml4AssertingPartyMetadataRepositoryTests.java @@ -20,16 +20,23 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import net.shibboleth.utilities.java.support.xml.SerializeSupport; +import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.junit.jupiter.api.BeforeEach; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; @@ -68,52 +75,59 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { OpenSamlInitializationService.initialize(); } - private String metadata; + private static MetadataDispatcher dispatcher = new MetadataDispatcher() + .addResponse("/entity.xml", readFile("test-metadata.xml")) + .addResponse("/entities.xml", readFile("test-entitiesdescriptor.xml")); - private String entitiesDescriptor; + private static MockWebServer web = new MockWebServer(); - @BeforeEach - public void setup() throws Exception { - ClassPathResource resource = new ClassPathResource("test-metadata.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.metadata = reader.lines().collect(Collectors.joining()); + private static String readFile(String fileName) { + try { + ClassPathResource resource = new ClassPathResource(fileName); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + return reader.lines().collect(Collectors.joining()); + } } - resource = new ClassPathResource("test-entitiesdescriptor.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.entitiesDescriptor = reader.lines().collect(Collectors.joining()); + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + @BeforeAll + public static void start() throws Exception { + web.setDispatcher(dispatcher); + web.start(); + } + + @AfterAll + public static void shutdown() throws Exception { + web.shutdown(); + } + @Test public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.metadata, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .build(); - AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); - assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); - assertThat(party.getSingleSignOnServiceLocation()) - .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); - assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(party.getVerificationX509Credentials()).hasSize(1); - assertThat(party.getEncryptionX509Credentials()).hasSize(1); - } + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url("/entity.xml").toString()) + .build(); + AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); } @Test public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.entitiesDescriptor, 3); - List parties = new ArrayList<>(); - OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString()) - .build() - .iterator() - .forEachRemaining(parties::add); - assertThat(parties).hasSize(2); - assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) - .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); - } + List parties = new ArrayList<>(); + OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(web.url("/entities.xml").toString()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); } @Test @@ -128,12 +142,10 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { @Test public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, "malformed", 3); - String url = server.url("/").toString(); - assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); - } + dispatcher.addResponse("/malformed", "malformed"); + String url = web.url("/malformed").toString(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); } @Test @@ -211,14 +223,13 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } @Test @@ -230,13 +241,12 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build()); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml4AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build()); } @Test @@ -326,14 +336,13 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository - .withMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml4AssertingPartyMetadataRepository + .withMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } private static String serialize(XMLObject object) { @@ -353,4 +362,28 @@ public class OpenSaml4AssertingPartyMetadataRepositoryTests { } } + private static final class MetadataDispatcher extends Dispatcher { + + private final MockResponse head = new MockResponse(); + + private final Map responses = new ConcurrentHashMap<>(); + + private MetadataDispatcher() { + } + + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + if ("HEAD".equals(request.getMethod())) { + return this.head; + } + return this.responses.get(request.getPath()); + } + + private MetadataDispatcher addResponse(String path, String body) { + this.responses.put(path, new MockResponse().setBody(body).setResponseCode(200)); + return this; + } + + } + } diff --git a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java index 27c0fd5adb..c01bb82ea6 100644 --- a/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java +++ b/saml2/saml2-service-provider/src/opensaml5Test/java/org/springframework/security/saml2/provider/service/registration/OpenSaml5AssertingPartyMetadataRepositoryTests.java @@ -20,16 +20,23 @@ import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.InputStreamReader; +import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; import net.shibboleth.shared.xml.SerializeSupport; +import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; -import org.junit.jupiter.api.BeforeEach; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; @@ -68,52 +75,59 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { OpenSamlInitializationService.initialize(); } - private String metadata; + private static MetadataDispatcher dispatcher = new MetadataDispatcher() + .addResponse("/entity.xml", readFile("test-metadata.xml")) + .addResponse("/entities.xml", readFile("test-entitiesdescriptor.xml")); - private String entitiesDescriptor; + private static MockWebServer web = new MockWebServer(); - @BeforeEach - public void setup() throws Exception { - ClassPathResource resource = new ClassPathResource("test-metadata.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.metadata = reader.lines().collect(Collectors.joining()); + private static String readFile(String fileName) { + try { + ClassPathResource resource = new ClassPathResource(fileName); + try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { + return reader.lines().collect(Collectors.joining()); + } } - resource = new ClassPathResource("test-entitiesdescriptor.xml"); - try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) { - this.entitiesDescriptor = reader.lines().collect(Collectors.joining()); + catch (IOException ex) { + throw new UncheckedIOException(ex); } } + @BeforeAll + public static void start() throws Exception { + web.setDispatcher(dispatcher); + web.start(); + } + + @AfterAll + public static void shutdown() throws Exception { + web.shutdown(); + } + @Test public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.metadata, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .build(); - AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); - assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); - assertThat(party.getSingleSignOnServiceLocation()) - .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); - assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); - assertThat(party.getVerificationX509Credentials()).hasSize(1); - assertThat(party.getEncryptionX509Credentials()).hasSize(1); - } + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url("/entity.xml").toString()) + .build(); + AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth"); + assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"); + assertThat(party.getSingleSignOnServiceLocation()) + .isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO"); + assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST); + assertThat(party.getVerificationX509Credentials()).hasSize(1); + assertThat(party.getEncryptionX509Credentials()).hasSize(1); } @Test public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, this.entitiesDescriptor, 3); - List parties = new ArrayList<>(); - OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString()) - .build() - .iterator() - .forEachRemaining(parties::add); - assertThat(parties).hasSize(2); - assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) - .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); - } + List parties = new ArrayList<>(); + OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(web.url("/entities.xml").toString()) + .build() + .iterator() + .forEachRemaining(parties::add); + assertThat(parties).hasSize(2); + assertThat(parties).extracting(AssertingPartyMetadata::getEntityId) + .contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth"); } @Test @@ -128,12 +142,10 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { @Test public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception { - try (MockWebServer server = new MockWebServer()) { - enqueue(server, "malformed", 3); - String url = server.url("/").toString(); - assertThatExceptionOfType(Saml2Exception.class) - .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); - } + dispatcher.addResponse("/malformed", "malformed"); + String url = web.url("/malformed").toString(); + assertThatExceptionOfType(Saml2Exception.class) + .isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build()); } @Test @@ -211,14 +223,13 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } @Test @@ -230,13 +241,12 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository - .withTrustedMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build()); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSaml5AssertingPartyMetadataRepository + .withTrustedMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build()); } @Test @@ -326,14 +336,13 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { String serialized = serialize(descriptor); Credential credential = TestOpenSamlObjects .getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID()); - try (MockWebServer server = new MockWebServer()) { - enqueue(server, serialized, 3); - AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository - .withMetadataLocation(server.url("/").toString()) - .verificationCredentials((c) -> c.add(credential)) - .build(); - assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); - } + String endpoint = "/" + UUID.randomUUID().toString(); + dispatcher.addResponse(endpoint, serialized); + AssertingPartyMetadataRepository parties = OpenSaml5AssertingPartyMetadataRepository + .withMetadataLocation(web.url(endpoint).toString()) + .verificationCredentials((c) -> c.add(credential)) + .build(); + assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull(); } private static String serialize(XMLObject object) { @@ -353,4 +362,28 @@ public class OpenSaml5AssertingPartyMetadataRepositoryTests { } } + private static final class MetadataDispatcher extends Dispatcher { + + private final MockResponse head = new MockResponse(); + + private final Map responses = new ConcurrentHashMap<>(); + + private MetadataDispatcher() { + } + + @Override + public MockResponse dispatch(RecordedRequest request) throws InterruptedException { + if ("HEAD".equals(request.getMethod())) { + return this.head; + } + return this.responses.get(request.getPath()); + } + + private MetadataDispatcher addResponse(String path, String body) { + this.responses.put(path, new MockResponse().setBody(body).setResponseCode(200)); + return this; + } + + } + }