Browse Source

Allow Protobuf codec extensions

Closes gh-35403
pull/35684/head
rstoyanchev 2 months ago
parent
commit
9331e1e86c
  1. 112
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java
  2. 26
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java
  3. 41
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java

112
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java

@ -129,13 +129,22 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -129,13 +129,22 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
MessageDecoderFunction decoderFunction = new MessageDecoderFunction(elementType, this.maxMessageSize);
MessageDecoderFunction decoderFunction =
new MessageDecoderFunction(elementType, this.maxMessageSize, initMessageSizeReader());
return Flux.from(inputStream)
.flatMapIterable(decoderFunction)
.doOnTerminate(decoderFunction::discard);
}
/**
* Return a reader for message size information encoded in the input stream.
* @since 7.0
*/
protected MessageSizeReader initMessageSizeReader() {
return new DefaultMessageSizeReader();
}
@Override
public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
@ -150,9 +159,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -150,9 +159,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
try {
Message.Builder builder = getMessageBuilder(targetType.toClass());
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount());
dataBuffer.toByteBuffer(byteBuffer);
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
merge(dataBuffer, builder);
return builder.build();
}
catch (IOException ex) {
@ -166,6 +173,17 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -166,6 +173,17 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
}
}
/**
* Use merge methods on {@link Message.Builder} to read a single message
* from the given {@code DataBuffer}.
* @since 7.0
*/
protected void merge(DataBuffer dataBuffer, Message.Builder builder) throws IOException {
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount());
dataBuffer.toByteBuffer(byteBuffer);
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
}
/**
* Create a new {@code Message.Builder} instance for the given class.
@ -196,15 +214,14 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -196,15 +214,14 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
private int messageBytesToRead;
private int offset;
private final MessageSizeReader messageSizeReader;
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) {
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize, MessageSizeReader messageSizeReader) {
this.elementType = elementType;
this.maxMessageSize = maxMessageSize;
this.messageSizeReader = messageSizeReader;
}
@Override
public Iterable<? extends Message> apply(DataBuffer input) {
try {
@ -214,9 +231,11 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -214,9 +231,11 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
do {
if (this.output == null) {
if (!readMessageSize(input)) {
Integer messageSize = this.messageSizeReader.readMessageSize(input);
if (messageSize == null) {
return messages;
}
this.messageBytesToRead = messageSize;
if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) {
throw new DataBufferLimitException(
"The number of bytes to read for message " +
@ -262,60 +281,89 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes @@ -262,60 +281,89 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
}
}
public void discard() {
if (this.output != null) {
DataBufferUtils.release(this.output);
}
}
}
/**
* Component to read the size of a message. Implementations must be
* stateful and expect size information is potentially split
* across input chunks.
* @since 7.0
*/
protected interface MessageSizeReader {
/**
* Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and
* {@code offset} fields if needed to allow processing of upcoming chunks.
* Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)}
* @return {@code true} when the message size is parsed successfully, {@code false} when the message size is
* truncated
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
* Read the message size from the given buffer. This method may be
* called multiple times before the message size is fully read.
* @return return the message size, or {@code null} if the data in the
* input buffer was insufficient
*/
private boolean readMessageSize(DataBuffer input) {
@Nullable Integer readMessageSize(DataBuffer input);
}
/**
* Default reader for Protobuf messages.
* <p>Parses the message size as a varint from the input stream.
* Inspired by {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)},
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
*/
private static class DefaultMessageSizeReader implements MessageSizeReader {
private int offset;
private int messageSize;
@Override
public @Nullable Integer readMessageSize(DataBuffer input) {
if (this.offset == 0) {
if (input.readableByteCount() == 0) {
return false;
return null;
}
int firstByte = input.read();
if ((firstByte & 0x80) == 0) {
this.messageBytesToRead = firstByte;
return true;
this.messageSize = firstByte;
return getAndReset();
}
this.messageBytesToRead = firstByte & 0x7f;
this.messageSize = firstByte & 0x7f;
this.offset = 7;
}
if (this.offset < 32) {
for (; this.offset < 32; this.offset += 7) {
if (input.readableByteCount() == 0) {
return false;
return null;
}
final int b = input.read();
this.messageBytesToRead |= (b & 0x7f) << this.offset;
this.messageSize |= (b & 0x7f) << this.offset;
if ((b & 0x80) == 0) {
this.offset = 0;
return true;
return getAndReset();
}
}
}
// Keep reading up to 64 bits.
for (; this.offset < 64; this.offset += 7) {
if (input.readableByteCount() == 0) {
return false;
return null;
}
final int b = input.read();
if ((b & 0x80) == 0) {
this.offset = 0;
return true;
return getAndReset();
}
}
this.offset = 0;
getAndReset();
throw new DecodingException("Cannot parse message size: malformed varint");
}
public void discard() {
if (this.output != null) {
DataBufferUtils.release(this.output);
}
private @Nullable Integer getAndReset() {
Integer result = (this.messageSize != 0 ? this.messageSize : null);
this.offset = 0;
this.messageSize = 0;
return result;
}
}

26
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java

@ -107,15 +107,10 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage @@ -107,15 +107,10 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage
}
private DataBuffer encodeValue(Message message, DataBufferFactory bufferFactory, boolean delimited) {
FastByteArrayOutputStream bos = new FastByteArrayOutputStream();
FastByteArrayOutputStream outputStream = new FastByteArrayOutputStream();
try {
if (delimited) {
message.writeDelimitedTo((OutputStream) bos);
}
else {
message.writeTo((OutputStream) bos);
}
byte[] bytes = bos.toByteArrayUnsafe();
writeMessage(message, delimited, outputStream);
byte[] bytes = outputStream.toByteArrayUnsafe();
return bufferFactory.wrap(bytes);
}
catch (IOException ex) {
@ -123,4 +118,19 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage @@ -123,4 +118,19 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage
}
}
/**
* Use write methods on {@link Message} to write to the given {@code OutputStream}.
* @since 7.0
*/
protected void writeMessage(
Message message, boolean delimited, OutputStream outputStream) throws IOException {
if (delimited) {
message.writeDelimitedTo(outputStream);
}
else {
message.writeTo(outputStream);
}
}
}

41
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java

@ -77,30 +77,47 @@ public class ProtobufHttpMessageWriter extends EncoderHttpMessageWriter<Message> @@ -77,30 +77,47 @@ public class ProtobufHttpMessageWriter extends EncoderHttpMessageWriter<Message>
@SuppressWarnings("unchecked")
@Override
public Mono<Void> write(Publisher<? extends Message> inputStream, ResolvableType elementType,
@Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map<String, Object> hints) {
@Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage, Map<String, Object> hints) {
try {
Message.Builder builder = getMessageBuilder(elementType.toClass());
Descriptors.Descriptor descriptor = builder.getDescriptorForType();
message.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName());
message.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName());
outputMessage.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName());
outputMessage.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName());
if (inputStream instanceof Flux) {
if (mediaType == null) {
message.getHeaders().setContentType(((HttpMessageEncoder<?>)getEncoder()).getStreamingMediaTypes().get(0));
}
else if (!ProtobufEncoder.DELIMITED_VALUE.equals(mediaType.getParameters().get(ProtobufEncoder.DELIMITED_KEY))) {
Map<String, String> parameters = new HashMap<>(mediaType.getParameters());
parameters.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
message.getHeaders().setContentType(new MediaType(mediaType.getType(), mediaType.getSubtype(), parameters));
}
outputMessage.getHeaders().setContentType(getStreamingContentType(mediaType));
}
return super.write(inputStream, elementType, mediaType, message, hints);
extendHeaders(outputMessage, hints);
return super.write(inputStream, elementType, mediaType, outputMessage, hints);
}
catch (Exception ex) {
return Mono.error(new EncodingException("Could not write Protobuf message: " + ex.getMessage(), ex));
}
}
/**
* Return the {@code MediaType} to use when the input Publisher is multivalued.
* @since 7.0
*/
protected MediaType getStreamingContentType(@Nullable MediaType mediaType) {
if (mediaType == null) {
return ((HttpMessageEncoder<?>) getEncoder()).getStreamingMediaTypes().get(0);
}
Map<String, String> params = new HashMap<>(mediaType.getParameters());
if (!ProtobufEncoder.DELIMITED_VALUE.equals(params.get(ProtobufEncoder.DELIMITED_KEY))) {
params.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
mediaType = new MediaType(mediaType, params);
}
return mediaType;
}
/**
* Make further updates to headers.
* @since 7.0
*/
protected void extendHeaders(ReactiveHttpOutputMessage message, Map<String, Object> hints) {
}
/**
* Create a new {@code Message.Builder} instance for the given class.
* <p>This method uses a ConcurrentHashMap for caching method lookups.

Loading…
Cancel
Save