Browse Source

Allow Protobuf codec extensions

Closes gh-35403
pull/35684/head
rstoyanchev 2 months ago
parent
commit
9331e1e86c
  1. 112
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java
  2. 26
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java
  3. 41
      spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java

112
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufDecoder.java

@ -129,13 +129,22 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType, public Flux<Message> decode(Publisher<DataBuffer> inputStream, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) { @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
MessageDecoderFunction decoderFunction = new MessageDecoderFunction(elementType, this.maxMessageSize); MessageDecoderFunction decoderFunction =
new MessageDecoderFunction(elementType, this.maxMessageSize, initMessageSizeReader());
return Flux.from(inputStream) return Flux.from(inputStream)
.flatMapIterable(decoderFunction) .flatMapIterable(decoderFunction)
.doOnTerminate(decoderFunction::discard); .doOnTerminate(decoderFunction::discard);
} }
/**
* Return a reader for message size information encoded in the input stream.
* @since 7.0
*/
protected MessageSizeReader initMessageSizeReader() {
return new DefaultMessageSizeReader();
}
@Override @Override
public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType, public Mono<Message> decodeToMono(Publisher<DataBuffer> inputStream, ResolvableType elementType,
@Nullable MimeType mimeType, @Nullable Map<String, Object> hints) { @Nullable MimeType mimeType, @Nullable Map<String, Object> hints) {
@ -150,9 +159,7 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
try { try {
Message.Builder builder = getMessageBuilder(targetType.toClass()); Message.Builder builder = getMessageBuilder(targetType.toClass());
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount()); merge(dataBuffer, builder);
dataBuffer.toByteBuffer(byteBuffer);
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
return builder.build(); return builder.build();
} }
catch (IOException ex) { catch (IOException ex) {
@ -166,6 +173,17 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
} }
} }
/**
* Use merge methods on {@link Message.Builder} to read a single message
* from the given {@code DataBuffer}.
* @since 7.0
*/
protected void merge(DataBuffer dataBuffer, Message.Builder builder) throws IOException {
ByteBuffer byteBuffer = ByteBuffer.allocate(dataBuffer.readableByteCount());
dataBuffer.toByteBuffer(byteBuffer);
builder.mergeFrom(CodedInputStream.newInstance(byteBuffer), this.extensionRegistry);
}
/** /**
* Create a new {@code Message.Builder} instance for the given class. * Create a new {@code Message.Builder} instance for the given class.
@ -196,15 +214,14 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
private int messageBytesToRead; private int messageBytesToRead;
private int offset; private final MessageSizeReader messageSizeReader;
public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize) { public MessageDecoderFunction(ResolvableType elementType, int maxMessageSize, MessageSizeReader messageSizeReader) {
this.elementType = elementType; this.elementType = elementType;
this.maxMessageSize = maxMessageSize; this.maxMessageSize = maxMessageSize;
this.messageSizeReader = messageSizeReader;
} }
@Override @Override
public Iterable<? extends Message> apply(DataBuffer input) { public Iterable<? extends Message> apply(DataBuffer input) {
try { try {
@ -214,9 +231,11 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
do { do {
if (this.output == null) { if (this.output == null) {
if (!readMessageSize(input)) { Integer messageSize = this.messageSizeReader.readMessageSize(input);
if (messageSize == null) {
return messages; return messages;
} }
this.messageBytesToRead = messageSize;
if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) { if (this.maxMessageSize > 0 && this.messageBytesToRead > this.maxMessageSize) {
throw new DataBufferLimitException( throw new DataBufferLimitException(
"The number of bytes to read for message " + "The number of bytes to read for message " +
@ -262,60 +281,89 @@ public class ProtobufDecoder extends ProtobufCodecSupport implements Decoder<Mes
} }
} }
public void discard() {
if (this.output != null) {
DataBufferUtils.release(this.output);
}
}
}
/**
* Component to read the size of a message. Implementations must be
* stateful and expect size information is potentially split
* across input chunks.
* @since 7.0
*/
protected interface MessageSizeReader {
/** /**
* Parse message size as a varint from the input stream, updating {@code messageBytesToRead} and * Read the message size from the given buffer. This method may be
* {@code offset} fields if needed to allow processing of upcoming chunks. * called multiple times before the message size is fully read.
* Inspired from {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)} * @return return the message size, or {@code null} if the data in the
* @return {@code true} when the message size is parsed successfully, {@code false} when the message size is * input buffer was insufficient
* truncated
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
*/ */
private boolean readMessageSize(DataBuffer input) { @Nullable Integer readMessageSize(DataBuffer input);
}
/**
* Default reader for Protobuf messages.
* <p>Parses the message size as a varint from the input stream.
* Inspired by {@link CodedInputStream#readRawVarint32(int, java.io.InputStream)},
* @see <a href="https://developers.google.com/protocol-buffers/docs/encoding#varints">Base 128 Varints</a>
*/
private static class DefaultMessageSizeReader implements MessageSizeReader {
private int offset;
private int messageSize;
@Override
public @Nullable Integer readMessageSize(DataBuffer input) {
if (this.offset == 0) { if (this.offset == 0) {
if (input.readableByteCount() == 0) { if (input.readableByteCount() == 0) {
return false; return null;
} }
int firstByte = input.read(); int firstByte = input.read();
if ((firstByte & 0x80) == 0) { if ((firstByte & 0x80) == 0) {
this.messageBytesToRead = firstByte; this.messageSize = firstByte;
return true; return getAndReset();
} }
this.messageBytesToRead = firstByte & 0x7f; this.messageSize = firstByte & 0x7f;
this.offset = 7; this.offset = 7;
} }
if (this.offset < 32) { if (this.offset < 32) {
for (; this.offset < 32; this.offset += 7) { for (; this.offset < 32; this.offset += 7) {
if (input.readableByteCount() == 0) { if (input.readableByteCount() == 0) {
return false; return null;
} }
final int b = input.read(); final int b = input.read();
this.messageBytesToRead |= (b & 0x7f) << this.offset; this.messageSize |= (b & 0x7f) << this.offset;
if ((b & 0x80) == 0) { if ((b & 0x80) == 0) {
this.offset = 0; return getAndReset();
return true;
} }
} }
} }
// Keep reading up to 64 bits. // Keep reading up to 64 bits.
for (; this.offset < 64; this.offset += 7) { for (; this.offset < 64; this.offset += 7) {
if (input.readableByteCount() == 0) { if (input.readableByteCount() == 0) {
return false; return null;
} }
final int b = input.read(); final int b = input.read();
if ((b & 0x80) == 0) { if ((b & 0x80) == 0) {
this.offset = 0; return getAndReset();
return true;
} }
} }
this.offset = 0; getAndReset();
throw new DecodingException("Cannot parse message size: malformed varint"); throw new DecodingException("Cannot parse message size: malformed varint");
} }
public void discard() { private @Nullable Integer getAndReset() {
if (this.output != null) { Integer result = (this.messageSize != 0 ? this.messageSize : null);
DataBufferUtils.release(this.output); this.offset = 0;
} this.messageSize = 0;
return result;
} }
} }

26
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufEncoder.java

@ -107,15 +107,10 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage
} }
private DataBuffer encodeValue(Message message, DataBufferFactory bufferFactory, boolean delimited) { private DataBuffer encodeValue(Message message, DataBufferFactory bufferFactory, boolean delimited) {
FastByteArrayOutputStream bos = new FastByteArrayOutputStream(); FastByteArrayOutputStream outputStream = new FastByteArrayOutputStream();
try { try {
if (delimited) { writeMessage(message, delimited, outputStream);
message.writeDelimitedTo((OutputStream) bos); byte[] bytes = outputStream.toByteArrayUnsafe();
}
else {
message.writeTo((OutputStream) bos);
}
byte[] bytes = bos.toByteArrayUnsafe();
return bufferFactory.wrap(bytes); return bufferFactory.wrap(bytes);
} }
catch (IOException ex) { catch (IOException ex) {
@ -123,4 +118,19 @@ public class ProtobufEncoder extends ProtobufCodecSupport implements HttpMessage
} }
} }
/**
* Use write methods on {@link Message} to write to the given {@code OutputStream}.
* @since 7.0
*/
protected void writeMessage(
Message message, boolean delimited, OutputStream outputStream) throws IOException {
if (delimited) {
message.writeDelimitedTo(outputStream);
}
else {
message.writeTo(outputStream);
}
}
} }

41
spring-web/src/main/java/org/springframework/http/codec/protobuf/ProtobufHttpMessageWriter.java

@ -77,30 +77,47 @@ public class ProtobufHttpMessageWriter extends EncoderHttpMessageWriter<Message>
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public Mono<Void> write(Publisher<? extends Message> inputStream, ResolvableType elementType, public Mono<Void> write(Publisher<? extends Message> inputStream, ResolvableType elementType,
@Nullable MediaType mediaType, ReactiveHttpOutputMessage message, Map<String, Object> hints) { @Nullable MediaType mediaType, ReactiveHttpOutputMessage outputMessage, Map<String, Object> hints) {
try { try {
Message.Builder builder = getMessageBuilder(elementType.toClass()); Message.Builder builder = getMessageBuilder(elementType.toClass());
Descriptors.Descriptor descriptor = builder.getDescriptorForType(); Descriptors.Descriptor descriptor = builder.getDescriptorForType();
message.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName()); outputMessage.getHeaders().add(X_PROTOBUF_SCHEMA_HEADER, descriptor.getFile().getName());
message.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName()); outputMessage.getHeaders().add(X_PROTOBUF_MESSAGE_HEADER, descriptor.getFullName());
if (inputStream instanceof Flux) { if (inputStream instanceof Flux) {
if (mediaType == null) { outputMessage.getHeaders().setContentType(getStreamingContentType(mediaType));
message.getHeaders().setContentType(((HttpMessageEncoder<?>)getEncoder()).getStreamingMediaTypes().get(0));
}
else if (!ProtobufEncoder.DELIMITED_VALUE.equals(mediaType.getParameters().get(ProtobufEncoder.DELIMITED_KEY))) {
Map<String, String> parameters = new HashMap<>(mediaType.getParameters());
parameters.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
message.getHeaders().setContentType(new MediaType(mediaType.getType(), mediaType.getSubtype(), parameters));
}
} }
return super.write(inputStream, elementType, mediaType, message, hints); extendHeaders(outputMessage, hints);
return super.write(inputStream, elementType, mediaType, outputMessage, hints);
} }
catch (Exception ex) { catch (Exception ex) {
return Mono.error(new EncodingException("Could not write Protobuf message: " + ex.getMessage(), ex)); return Mono.error(new EncodingException("Could not write Protobuf message: " + ex.getMessage(), ex));
} }
} }
/**
* Return the {@code MediaType} to use when the input Publisher is multivalued.
* @since 7.0
*/
protected MediaType getStreamingContentType(@Nullable MediaType mediaType) {
if (mediaType == null) {
return ((HttpMessageEncoder<?>) getEncoder()).getStreamingMediaTypes().get(0);
}
Map<String, String> params = new HashMap<>(mediaType.getParameters());
if (!ProtobufEncoder.DELIMITED_VALUE.equals(params.get(ProtobufEncoder.DELIMITED_KEY))) {
params.put(ProtobufEncoder.DELIMITED_KEY, ProtobufEncoder.DELIMITED_VALUE);
mediaType = new MediaType(mediaType, params);
}
return mediaType;
}
/**
* Make further updates to headers.
* @since 7.0
*/
protected void extendHeaders(ReactiveHttpOutputMessage message, Map<String, Object> hints) {
}
/** /**
* Create a new {@code Message.Builder} instance for the given class. * Create a new {@code Message.Builder} instance for the given class.
* <p>This method uses a ConcurrentHashMap for caching method lookups. * <p>This method uses a ConcurrentHashMap for caching method lookups.

Loading…
Cancel
Save