diff --git a/spring-messaging/src/main/java/org/springframework/messaging/MessageHeaders.java b/spring-messaging/src/main/java/org/springframework/messaging/MessageHeaders.java index 0ebea02993b..85697ea437f 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/MessageHeaders.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/MessageHeaders.java @@ -20,11 +20,10 @@ import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; -import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.List; +import java.util.HashSet; import java.util.Map; import java.util.Set; import java.util.UUID; @@ -64,6 +63,7 @@ import org.springframework.util.IdGenerator; * @author Arjen Poutsma * @author Mark Fisher * @author Gary Russell + * @author Juergen Hoeller * @since 4.0 * @see org.springframework.messaging.support.MessageBuilder * @see org.springframework.messaging.support.MessageHeaderAccessor @@ -140,6 +140,21 @@ public class MessageHeaders implements Map, Serializable { } } + /** + * Copy constructor which allows for ignoring certain entries. + * Used for serialization without non-serializable entries. + * @param original the MessageHeaders to copy + * @param keysToIgnore the keys of the entries to ignore + */ + private MessageHeaders(MessageHeaders original, Set keysToIgnore) { + this.headers = new HashMap(original.headers.size() - keysToIgnore.size()); + for (Map.Entry entry : original.headers.entrySet()) { + if (!keysToIgnore.contains(entry.getKey())) { + this.headers.put(entry.getKey(), entry.getValue()); + } + } + } + protected Map getRawHeaders() { return this.headers; @@ -165,6 +180,7 @@ public class MessageHeaders implements Map, Serializable { return get(ERROR_CHANNEL); } + @SuppressWarnings("unchecked") public T get(Object key, Class type) { Object value = this.headers.get(key); @@ -179,23 +195,6 @@ public class MessageHeaders implements Map, Serializable { } - @Override - public boolean equals(Object other) { - return (this == other || - (other instanceof MessageHeaders && this.headers.equals(((MessageHeaders) other).headers))); - } - - @Override - public int hashCode() { - return this.headers.hashCode(); - } - - @Override - public String toString() { - return this.headers.toString(); - } - - // Delegating Map implementation public boolean containsKey(Object key) { @@ -269,23 +268,47 @@ public class MessageHeaders implements Map, Serializable { // Serialization methods private void writeObject(ObjectOutputStream out) throws IOException { - List keysToRemove = new ArrayList(); + Set keysToIgnore = new HashSet(); for (Map.Entry entry : this.headers.entrySet()) { if (!(entry.getValue() instanceof Serializable)) { - keysToRemove.add(entry.getKey()); + keysToIgnore.add(entry.getKey()); } } - for (String key : keysToRemove) { - if (logger.isInfoEnabled()) { - logger.info("Removing non-serializable header: " + key); + + if (keysToIgnore.isEmpty()) { + // All entries are serializable -> serialize the regular MessageHeaders instance + out.defaultWriteObject(); + } + else { + // Some non-serializable entries -> serialize a temporary MessageHeaders copy + if (logger.isDebugEnabled()) { + logger.debug("Ignoring non-serializable message headers: " + keysToIgnore); } - this.headers.remove(key); + out.writeObject(new MessageHeaders(this, keysToIgnore)); } - out.defaultWriteObject(); } private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject(); } + + // equals, hashCode, toString + + @Override + public boolean equals(Object other) { + return (this == other || + (other instanceof MessageHeaders && this.headers.equals(((MessageHeaders) other).headers))); + } + + @Override + public int hashCode() { + return this.headers.hashCode(); + } + + @Override + public String toString() { + return this.headers.toString(); + } + } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/MessageHeadersTests.java b/spring-messaging/src/test/java/org/springframework/messaging/MessageHeadersTests.java index 0e557cce6ab..cd25691b43e 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/MessageHeadersTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/MessageHeadersTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,6 @@ package org.springframework.messaging; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -29,6 +25,8 @@ import java.util.concurrent.atomic.AtomicLong; import org.junit.Test; +import org.springframework.util.SerializationTestUtils; + import static org.junit.Assert.*; /** @@ -36,10 +34,10 @@ import static org.junit.Assert.*; * * @author Rossen Stoyanchev * @author Gary Russell + * @author Juergen Hoeller */ public class MessageHeadersTests { - @Test public void testTimestamp() { MessageHeaders headers = new MessageHeaders(null); @@ -164,9 +162,11 @@ public class MessageHeadersTests { map.put("name", "joe"); map.put("age", 42); MessageHeaders input = new MessageHeaders(map); - MessageHeaders output = (MessageHeaders) serializeAndDeserialize(input); + MessageHeaders output = (MessageHeaders) SerializationTestUtils.serializeAndDeserialize(input); assertEquals("joe", output.get("name")); assertEquals(42, output.get("age")); + assertEquals("joe", input.get("name")); + assertEquals(42, input.get("age")); } @Test @@ -176,37 +176,25 @@ public class MessageHeadersTests { map.put("name", "joe"); map.put("address", address); MessageHeaders input = new MessageHeaders(map); - MessageHeaders output = (MessageHeaders) serializeAndDeserialize(input); + MessageHeaders output = (MessageHeaders) SerializationTestUtils.serializeAndDeserialize(input); assertEquals("joe", output.get("name")); assertNull(output.get("address")); + assertEquals("joe", input.get("name")); + assertSame(address, input.get("address")); } @Test - public void subClassWithCustomIdAndNoTimestamp() { + public void subclassWithCustomIdAndNoTimestamp() { final AtomicLong id = new AtomicLong(); @SuppressWarnings("serial") class MyMH extends MessageHeaders { - public MyMH() { super(null, new UUID(0, id.incrementAndGet()), -1L); } - } MessageHeaders headers = new MyMH(); assertEquals("00000000-0000-0000-0000-000000000001", headers.getId().toString()); assertEquals(1, headers.size()); } - private static Object serializeAndDeserialize(Object object) throws Exception { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream out = new ObjectOutputStream(baos); - out.writeObject(object); - out.close(); - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - ObjectInputStream in = new ObjectInputStream(bais); - Object result = in.readObject(); - in.close(); - return result; - } - }