diff --git a/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java b/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java index fb406cb15c..efeabaa90a 100644 --- a/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java +++ b/remoting/src/main/java/org/springframework/security/remoting/dns/JndiDnsResolver.java @@ -61,16 +61,28 @@ public class JndiDnsResolver implements DnsResolver { @Override public String resolveServiceEntry(String serviceType, String domain) { - return resolveServiceEntry(serviceType, domain, this.ctxFactory.getCtx()); + return resolveServiceEntry(serviceType, domain, this.ctxFactory.getCtx()).getHostName(); } @Override public String resolveServiceIpAddress(String serviceType, String domain) { DirContext ctx = this.ctxFactory.getCtx(); - String hostname = resolveServiceEntry(serviceType, domain, ctx); + String hostname = resolveServiceEntry(serviceType, domain, ctx).getHostName(); return resolveIpAddress(hostname, ctx); } + /** + * @author Kathryn Newbould + * @since 5.4.1 + * @return String of ip address and port, format [ip_address]:[port] of service if found + * @throws DnsLookupException if not found + */ + public String resolveServiceIpAddressAndPort(String serviceType, String domain) { + DirContext ctx = this.ctxFactory.getCtx(); + ConnectionInfo hostInfo = resolveServiceEntry(serviceType, domain, ctx); + return resolveIpAddress(hostInfo.getHostName(), ctx) + ":" + hostInfo.getPort(); + } + // This method is needed, so that we can use only one DirContext for // resolveServiceIpAddress(). private String resolveIpAddress(String hostname, DirContext ctx) { @@ -88,8 +100,9 @@ public class JndiDnsResolver implements DnsResolver { // This method is needed, so that we can use only one DirContext for // resolveServiceIpAddress(). - private String resolveServiceEntry(String serviceType, String domain, DirContext ctx) { + private ConnectionInfo resolveServiceEntry(String serviceType, String domain, DirContext ctx) { String result = null; + String port = null; try { String query = new StringBuilder("_").append(serviceType).append("._tcp.").append(domain).toString(); Attribute dnsRecord = lookup(query, ctx, "SRV"); @@ -107,15 +120,18 @@ public class JndiDnsResolver implements DnsResolver { int priority = Integer.parseInt(record[0]); int weight = Integer.parseInt(record[1]); // we have a new highest Priority, so forget also the highest weight + int SERVICE_RECORD_PORT_INDEX = 2; if (priority < highestPriority || highestPriority == -1) { highestPriority = priority; highestWeight = weight; result = record[3].trim(); + port = record[SERVICE_RECORD_PORT_INDEX].trim(); } // same priority, but higher weight if (priority == highestPriority && weight > highestWeight) { highestWeight = weight; result = record[3].trim(); + port = record[SERVICE_RECORD_PORT_INDEX].trim(); } } } @@ -126,7 +142,7 @@ public class JndiDnsResolver implements DnsResolver { if (result.endsWith(".")) { result = result.substring(0, result.length() - 1); } - return result; + return new ConnectionInfo(result, port); } private Attribute lookup(String query, DirContext ictx, String recordType) { @@ -159,4 +175,22 @@ public class JndiDnsResolver implements DnsResolver { } + private static class ConnectionInfo { + private final String hostName; + private final String port; + + public ConnectionInfo(String hostName, String port) { + this.hostName = hostName; + this.port = port; + } + + public String getHostName() { + return hostName; + } + + public String getPort() { + return port; + } + } + } diff --git a/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java b/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java index c66f25013b..471a999d0a 100644 --- a/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java +++ b/remoting/src/test/java/org/springframework/security/remoting/dns/JndiDnsResolverTests.java @@ -95,6 +95,16 @@ public class JndiDnsResolverTests { assertThat(ipAddress).isEqualTo("63.246.7.80"); } + @Test + public void resolveServiceIpAddressWithPort() throws Exception { + BasicAttributes srvRecords = createSrvRecords(); + BasicAttributes aRecords = new BasicAttributes("A", "63.246.7.80"); + given(this.context.getAttributes("_ldap._tcp.springsource.com", new String[] { "SRV" })).willReturn(srvRecords); + given(this.context.getAttributes("kdc.springsource.com", new String[] { "A" })).willReturn(aRecords); + String ipAddress = this.dnsResolver.resolveServiceIpAddressAndPort("ldap", "springsource.com"); + assertThat(ipAddress).isEqualTo("63.246.7.80:389"); + } + @Test public void testUnknowError() throws Exception { given(this.context.getAttributes(any(String.class), any(String[].class)))