You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/29 18:15:45 UTC
[arrow] branch master updated: ARROW-2704: [Java] Change
MessageReader API to improve custom message handling for streams
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new cd162f5 ARROW-2704: [Java] Change MessageReader API to improve custom message handling for streams
cd162f5 is described below
commit cd162f59847772925e731f2bcd78c87bd078f9d8
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Sun Jul 29 14:15:38 2018 -0400
ARROW-2704: [Java] Change MessageReader API to improve custom message handling for streams
Made a number of improvements to make stream processing for messages easier without having to load Arrow record and dictionary batches.
### Changes Made
- Changed `MessageReader` interface to have a `readNext` method that takes an extension of `MessageHolder` to return the message data. See details below.
- Added static function in ArrowStreamWriter to write the EOS identifier
- Moved intToBytes to MessageSerializer (where bytesToInt is), and it now works with an existing byte array
- Removed the abstract modifier from startInternal and endInternal of `ArrowWriter` so that subclasses are not forced to implement them if they are not needed.
- Made deserialize* functions in `MessageSerializer` more consistent. In general, each can deserialize from a `ReadChannel` or from data directly.
### MessageReader API
The changes to `MessageReader` are to make it easy to subclass and implement custom handling of reading messages/data. The input to `readNext` is a `MessageHolder` instance, which can be extended to store message data and additional information during the read.
### Usage in ArrowStreamReader
There is one implementation of `MessageReader`, `MessageChannelReader`, which reads the body data into an off-heap `ArrowBuf`. `ArrowStreamReader` uses this implementation to do the message reading and then handles the deserialization of batch data into vectors.
### Testing
Updated existing tests for the new APIs, and added tests for writing/reading a zero-length record batch and for message alignment.
Author: Bryan Cutler <cu...@gmail.com>
Closes #2139 from BryanCutler/java-stream-low-level-ARROW-2704 and squashes the following commits:
13861329 <Bryan Cutler> removed function in MessageSerializer that's not totally necessary
14eba720 <Bryan Cutler> removed MessageReader interface, MessageChannelHolder
53629862 <Bryan Cutler> fixed some comments
f2fca5cf <Bryan Cutler> Added MessageChannelResult to return from readMessage, and simplified MessageChannelReader to implement MessageReader directly
d109c434 <Bryan Cutler> Made MessageReader autoclosable and throws generic exception
36dcc202 <Bryan Cutler> updated javadoc wording
fab31420 <Bryan Cutler> Fixed up some MessageSerializer docs
a2169e01 <Bryan Cutler> remove accidental imports
b4f9d6eb <Bryan Cutler> added test for message alignment
68246ff2 <Bryan Cutler> Separated out message and prefix writing functions
0f0a6075 <Bryan Cutler> revert accidental change to integration tests made during testing
e6837e50 <Bryan Cutler> fixed bug when reading zero-length batches, added test
7bed8503 <Bryan Cutler> Changed MessageReader to abstract class with MessageReadHolder to hold data read
acdf50e2 <Bryan Cutler> Made readMessageLength and loadMessage static functions
1b78de9c <Bryan Cutler> made MessageChannelReader more friendly to low-level msg processing, added static method to write EOS, removed abstract from beginInternal and endInternal, made intToBytes more consistent and added tests
---
.../apache/arrow/vector/ipc/ArrowStreamReader.java | 62 +++--
.../apache/arrow/vector/ipc/ArrowStreamWriter.java | 31 ++-
.../org/apache/arrow/vector/ipc/ArrowWriter.java | 9 +-
.../org/apache/arrow/vector/ipc/WriteChannel.java | 14 +-
.../vector/ipc/message/MessageChannelReader.java | 72 ++----
.../vector/ipc/message/MessageChannelResult.java | 104 ++++++++
.../arrow/vector/ipc/message/MessageHolder.java | 30 +++
.../arrow/vector/ipc/message/MessageReader.java | 65 -----
.../vector/ipc/message/MessageSerializer.java | 274 +++++++++++++++------
.../arrow/vector/ipc/MessageSerializerTest.java | 53 ++++
.../apache/arrow/vector/ipc/TestArrowStream.java | 26 ++
11 files changed, 508 insertions(+), 232 deletions(-)
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java
index d1e4802..74c6074 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamReader.java
@@ -23,32 +23,29 @@ import java.io.InputStream;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
-import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageChannelReader;
-import org.apache.arrow.vector.ipc.message.MessageReader;
+import org.apache.arrow.vector.ipc.message.MessageHolder;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
-import org.apache.arrow.vector.ipc.ReadChannel;
import org.apache.arrow.vector.types.pojo.Schema;
/**
- * This classes reads from an input stream and produces ArrowRecordBatches.
+ * This class reads from an input stream and produces ArrowRecordBatches.
*/
public class ArrowStreamReader extends ArrowReader {
- private MessageReader messageReader;
+ private MessageChannelReader messageReader;
/**
- * Constructs a streaming reader using the MessageReader interface. Non-blocking.
+ * Constructs a streaming reader using a MessageChannelReader. Non-blocking.
*
- * @param messageReader interface to get read messages
+ * @param messageReader reader used to get messages from a ReadChannel
* @param allocator to allocate new buffers
*/
- public ArrowStreamReader(MessageReader messageReader, BufferAllocator allocator) {
+ public ArrowStreamReader(MessageChannelReader messageReader, BufferAllocator allocator) {
super(allocator);
this.messageReader = messageReader;
}
@@ -60,7 +57,7 @@ public class ArrowStreamReader extends ArrowReader {
* @param allocator to allocate new buffers
*/
public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) {
- this(new MessageChannelReader(new ReadChannel(in)), allocator);
+ this(new MessageChannelReader(new ReadChannel(in), allocator), allocator);
}
/**
@@ -101,19 +98,23 @@ public class ArrowStreamReader extends ArrowReader {
*/
public boolean loadNextBatch() throws IOException {
prepareLoadNextBatch();
-
- Message message = messageReader.readNextMessage();
+ MessageHolder holder = new MessageHolder();
// Reached EOS
- if (message == null) {
+ if (!messageReader.readNext(holder)) {
return false;
}
- if (message.headerType() != MessageHeader.RecordBatch) {
- throw new IOException("Expected RecordBatch but header was " + message.headerType());
+ if (holder.message.headerType() != MessageHeader.RecordBatch) {
+ throw new IOException("Expected RecordBatch but header was " + holder.message.headerType());
+ }
+
+ // For zero-length batches, need an empty buffer to deserialize the batch
+ if (holder.bodyBuffer == null) {
+ holder.bodyBuffer = allocator.getEmpty();
}
- ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(messageReader, message, allocator);
+ ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(holder.message, holder.bodyBuffer);
loadRecordBatch(batch);
return true;
}
@@ -125,7 +126,17 @@ public class ArrowStreamReader extends ArrowReader {
*/
@Override
protected Schema readSchema() throws IOException {
- return MessageSerializer.deserializeSchema(messageReader);
+ MessageHolder holder = new MessageHolder();
+
+ if (!messageReader.readNext(holder)) {
+ throw new IOException("Unexpected end of input. Missing schema.");
+ }
+
+ if (holder.message.headerType() != MessageHeader.Schema) {
+ throw new IOException("Expected schema but header was " + holder.message.headerType());
+ }
+
+ return MessageSerializer.deserializeSchema(holder.message);
}
/**
@@ -137,12 +148,21 @@ public class ArrowStreamReader extends ArrowReader {
*/
@Override
protected ArrowDictionaryBatch readDictionary() throws IOException {
- Message message = messageReader.readNextMessage();
+ MessageHolder holder = new MessageHolder();
+
+ if (!messageReader.readNext(holder)) {
+ throw new IOException("Unexpected end of input. Expected DictionaryBatch");
+ }
+
+ if (holder.message.headerType() != MessageHeader.DictionaryBatch) {
+ throw new IOException("Expected DictionaryBatch but header was " + holder.message.headerType());
+ }
- if (message.headerType() != MessageHeader.DictionaryBatch) {
- throw new IOException("Expected DictionaryBatch but header was " + message.headerType());
+ // For zero-length batches, need an empty buffer to deserialize the batch
+ if (holder.bodyBuffer == null) {
+ holder.bodyBuffer = allocator.getEmpty();
}
- return MessageSerializer.deserializeDictionaryBatch(messageReader, message, allocator);
+ return MessageSerializer.deserializeDictionaryBatch(holder.message, holder.bodyBuffer);
}
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
index 784ce08..06439ce 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
@@ -26,22 +26,47 @@ import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
+/**
+ * Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel
+ */
public class ArrowStreamWriter extends ArrowWriter {
+ /**
+ * Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream.
+ *
+ * @param root Existing VectorSchemaRoot with vectors to be written.
+ * @param provider DictionaryProvider for any vectors that are dictionary encoded.
+ * (Optional, can be null)
+ * @param out OutputStream for writing.
+ */
public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, OutputStream out) {
this(root, provider, Channels.newChannel(out));
}
+ /**
+ * Construct an ArrowStreamWriter with an optional DictionaryProvider for the WritableByteChannel.
+ *
+ * @param root Existing VectorSchemaRoot with vectors to be written.
+ * @param provider DictionaryProvider for any vectors that are dictionary encoded.
+ * (Optional, can be null)
+ * @param out WritableByteChannel for writing.
+ */
public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
}
- @Override
- protected void startInternal(WriteChannel out) throws IOException {
+ /**
+ * Write an EOS identifier to the WriteChannel.
+ *
+ * @param out Open WriteChannel with an active Arrow stream.
+ * @throws IOException
+ */
+ public static void writeEndOfStream(WriteChannel out) throws IOException {
+ out.writeIntLittleEndian(0);
}
@Override
protected void endInternal(WriteChannel out) throws IOException {
- out.writeIntLittleEndian(0);
+ writeEndOfStream(out);
}
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index 8bc6402..93f2521 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -42,6 +42,9 @@ import org.slf4j.LoggerFactory;
import com.google.common.collect.ImmutableList;
+/**
+ * Abstract base class for implementing Arrow writers for IPC over a WriteChannel
+ */
public abstract class ArrowWriter implements AutoCloseable {
protected static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class);
@@ -149,9 +152,11 @@ public abstract class ArrowWriter implements AutoCloseable {
}
}
- protected abstract void startInternal(WriteChannel out) throws IOException;
+ protected void startInternal(WriteChannel out) throws IOException {
+ }
- protected abstract void endInternal(WriteChannel out) throws IOException;
+ protected void endInternal(WriteChannel out) throws IOException {
+ }
@Override
public void close() {
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java
index da500aa..36e8320 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java
@@ -26,6 +26,7 @@ import com.google.flatbuffers.FlatBufferBuilder;
import io.netty.buffer.ArrowBuf;
import org.apache.arrow.vector.ipc.message.FBSerializable;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -76,17 +77,10 @@ public class WriteChannel implements AutoCloseable {
return length;
}
- public static byte[] intToBytes(int value) {
- byte[] outBuffer = new byte[4];
- outBuffer[3] = (byte) (value >>> 24);
- outBuffer[2] = (byte) (value >>> 16);
- outBuffer[1] = (byte) (value >>> 8);
- outBuffer[0] = (byte) (value >>> 0);
- return outBuffer;
- }
-
public long writeIntLittleEndian(int v) throws IOException {
- return write(intToBytes(v));
+ byte[] outBuffer = new byte[4];
+ MessageSerializer.intToBytes(v, outBuffer);
+ return write(outBuffer);
}
public void write(ArrowBuf buffer) throws IOException {
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java
index 5bc3e1f..399f225 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelReader.java
@@ -18,79 +18,52 @@
package org.apache.arrow.vector.ipc.message;
+import java.io.IOException;
-import io.netty.buffer.ArrowBuf;
-import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.ipc.ReadChannel;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
/**
* Reads a sequence of messages using a ReadChannel.
*/
-public class MessageChannelReader implements MessageReader {
-
- private ReadChannel in;
+public class MessageChannelReader implements AutoCloseable {
+ protected ReadChannel in;
+ protected BufferAllocator allocator;
/**
- * Construct from an existing ReadChannel.
+ * Construct a MessageReader to read streaming messages from an existing ReadChannel.
*
* @param in Channel to read messages from
+ * @param allocator BufferAllocator used to read Message body into an ArrowBuf.
*/
- public MessageChannelReader(ReadChannel in) {
+ public MessageChannelReader(ReadChannel in, BufferAllocator allocator) {
this.in = in;
+ this.allocator = allocator;
}
/**
- * Read the next message from the ReadChannel.
+ * Read a Message from the ReadChannel and populate holder if a valid message was read.
*
- * @return A Message or null if ReadChannel has no more messages, indicated by message length of 0
+ * @param holder Message and message information that is populated when read by implementation
+ * @return true if a valid Message was read, false if end-of-stream
* @throws IOException
*/
- @Override
- public Message readNextMessage() throws IOException {
- // Read the message size. There is an i32 little endian prefix.
- ByteBuffer buffer = ByteBuffer.allocate(4);
- if (in.readFully(buffer) != 4) {
- return null;
- }
- int messageLength = MessageSerializer.bytesToInt(buffer.array());
- if (messageLength == 0) {
- return null;
- }
+ public boolean readNext(MessageHolder holder) throws IOException {
- buffer = ByteBuffer.allocate(messageLength);
- if (in.readFully(buffer) != messageLength) {
- throw new IOException(
- "Unexpected end of stream trying to read message.");
+ // Read the flatbuf message and check for end-of-stream
+ MessageChannelResult result = MessageSerializer.readMessage(in);
+ if (!result.hasMessage()) {
+ return false;
}
- buffer.rewind();
-
- return Message.getRootAsMessage(buffer);
- }
+ holder.message = result.getMessage();
- /**
- * Read a message body from the ReadChannel.
- *
- * @param message Read message that is followed by a body of data
- * @param allocator BufferAllocator to allocate memory for body data
- * @return ArrowBuf containing the message body data
- * @throws IOException
- */
- @Override
- public ArrowBuf readMessageBody(Message message, BufferAllocator allocator) throws IOException {
-
- int bodyLength = (int) message.bodyLength();
-
- // Now read the record batch body
- ArrowBuf buffer = allocator.buffer(bodyLength);
- if (in.readFully(buffer, bodyLength) != bodyLength) {
- throw new IOException("Unexpected end of input trying to read batch.");
+ // Read message body data if defined in message
+ if (result.messageHasBody()) {
+ int bodyLength = (int) result.getMessageBodyLength();
+ holder.bodyBuffer = MessageSerializer.readMessageBody(in, bodyLength, allocator);
}
- return buffer;
+ return true;
}
/**
@@ -98,7 +71,6 @@ public class MessageChannelReader implements MessageReader {
*
* @return number of bytes
*/
- @Override
public long bytesRead() {
return in.bytesRead();
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelResult.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelResult.java
new file mode 100644
index 0000000..0b732f2
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageChannelResult.java
@@ -0,0 +1,104 @@
+/**
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.arrow.vector.ipc.message;
+
+import java.nio.ByteBuffer;
+
+import org.apache.arrow.flatbuf.Message;
+
+/**
+* Class to hold resulting Message and message information when reading messages from a ReadChannel.
+*/
+public class MessageChannelResult {
+
+ /**
+ * Construct a container to hold a message result.
+ *
+ * @param messageLength the length of the message read in bytes
+ * @param messageBuffer contains the raw bytes of the message
+ * @param message the realized flatbuf Message
+ */
+ public MessageChannelResult(int messageLength, ByteBuffer messageBuffer, Message message) {
+ this.messageLength = messageLength;
+ this.messageBuffer = messageBuffer;
+ this.message = message;
+ }
+
+ /**
+ * Returns status indicating if the MessageResult has a valid message.
+ *
+ * @return true if the result contains a valid message
+ */
+ public boolean hasMessage() {
+ return message != null;
+ }
+
+ /**
+ * Get the length of the message in bytes.
+ *
+ * @return number of bytes in the message buffer.
+ */
+ public int getMessageLength() {
+ return messageLength;
+ }
+
+ /**
+ * Get the buffer containing the raw message bytes.
+ *
+ * @return buffer containing the message
+ */
+ public ByteBuffer getMessageBuffer() {
+ return messageBuffer;
+ }
+
+ /**
+ * Check if the message is valid and is followed by a body.
+ *
+ * @return true if message has a body
+ */
+ public boolean messageHasBody() {
+ return message != null && message.bodyLength() > 0;
+ }
+
+ /**
+ * Get the length of the message body.
+ *
+ * @return number of bytes of the message body
+ */
+ public long getMessageBodyLength() {
+ long bodyLength = 0;
+ if (message != null) {
+ bodyLength = message.bodyLength();
+ }
+ return bodyLength;
+ }
+
+ /**
+ * Get the realized flatbuf Message.
+ *
+ * @return Message
+ */
+ public Message getMessage() {
+ return message;
+ }
+
+ private int messageLength;
+ private ByteBuffer messageBuffer;
+ private Message message;
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageHolder.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageHolder.java
new file mode 100644
index 0000000..975a9af
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageHolder.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.vector.ipc.message;
+
+import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.flatbuf.Message;
+
+/**
+ * Class to hold a Message and body when reading messages through a MessageChannelReader.
+ */
+public class MessageHolder {
+ public Message message;
+ public ArrowBuf bodyBuffer;
+}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageReader.java
deleted file mode 100644
index b277c58..0000000
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageReader.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.arrow.vector.ipc.message;
-
-
-import io.netty.buffer.ArrowBuf;
-import org.apache.arrow.flatbuf.Message;
-import org.apache.arrow.memory.BufferAllocator;
-
-import java.io.IOException;
-
-/**
- * Interface for reading a sequence of messages.
- */
-public interface MessageReader {
-
- /**
- * Read the next message in the sequence.
- *
- * @return The read message or null if reached the end of the message sequence
- * @throws IOException
- */
- Message readNextMessage() throws IOException;
-
- /**
- * When a message is followed by a body of data, read that data into an ArrowBuf. This should
- * only be called when a Message has a body length > 0.
- *
- * @param message Read message that is followed by a body of data
- * @param allocator BufferAllocator to allocate memory for body data
- * @return An ArrowBuf containing the body of the message that was read
- * @throws IOException
- */
- ArrowBuf readMessageBody(Message message, BufferAllocator allocator) throws IOException;
-
- /**
- * Return the current number of bytes that have been read.
- *
- * @return number of bytes read
- */
- long bytesRead();
-
- /**
- * Close any resource opened by the message reader, not including message body allocations.
- *
- * @throws IOException
- */
- void close() throws IOException;
-}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java
index 0b409df..7371991 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java
@@ -56,6 +56,12 @@ import io.netty.buffer.ArrowBuf;
*/
public class MessageSerializer {
+ /**
+ * Convert an array of 4 bytes to a little endian i32 value.
+ *
+ * @param bytes byte array with minimum length of 4
+ * @return converted little endian 32-bit integer
+ */
public static int bytesToInt(byte[] bytes) {
return ((bytes[3] & 255) << 24) +
((bytes[2] & 255) << 16) +
@@ -64,11 +70,49 @@ public class MessageSerializer {
}
/**
+ * Convert an integer to a 4 byte array.
+ *
+ * @param value integer value input
+ * @param bytes existing byte array with minimum length of 4 to contain the conversion output
+ */
+ public static void intToBytes(int value, byte[] bytes) {
+ bytes[3] = (byte) (value >>> 24);
+ bytes[2] = (byte) (value >>> 16);
+ bytes[1] = (byte) (value >>> 8);
+ bytes[0] = (byte) (value >>> 0);
+ }
+
+ /**
+ * Aligns the message to 8 byte boundary and adjusts messageLength accordingly, then writes
+ * the message length prefix and message buffer to the Channel.
+ *
+ * @param out Output Channel
+ * @param messageLength Number of bytes in the message buffer, written as little Endian prefix
+ * @param messageBuffer Message buffer to be written
+ * @return Number of bytes written
+ * @return
+ * @throws IOException
+ */
+ public static int writeMessageBufferAligned(WriteChannel out, int messageLength, ByteBuffer messageBuffer) throws IOException {
+
+ // ensure that message aligns to 8 byte padding - 4 bytes for size, then message body
+ if ((messageLength + 4) % 8 != 0) {
+ messageLength += 8 - (messageLength + 4) % 8;
+ }
+ out.writeIntLittleEndian(messageLength);
+ out.write(messageBuffer);
+ out.align();
+
+ // any bytes written are already captured by our size modification above
+ return messageLength + 4;
+ }
+
+ /**
* Serialize a schema object.
*
* @param out where to write the schema
* @param schema the object to serialize to out
- * @return the resulting size of the serialized schema
+ * @return the number of bytes written
* @throws IOException if something went wrong
*/
public static long serialize(WriteChannel out, Schema schema) throws IOException {
@@ -79,49 +123,41 @@ public class MessageSerializer {
int schemaOffset = schema.getSchema(builder);
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0);
- int size = serializedMessage.remaining();
- // ensure that message aligns to 8 byte padding - 4 bytes for size, then message body
- if ((size + 4) % 8 != 0) {
- size += 8 - (size + 4) % 8;
- }
-
- out.writeIntLittleEndian(size);
- out.write(serializedMessage);
- out.align(); // any bytes written are already captured by our size modification above
+ int messageLength = serializedMessage.remaining();
- assert (size + 4) % 8 == 0;
- return size + 4;
+ int bytesWritten = writeMessageBufferAligned(out, messageLength, serializedMessage);
+ assert bytesWritten % 8 == 0;
+ return bytesWritten;
}
/**
- * Deserializes a schema object. Format is from serialize().
+ * Deserializes an Arrow Schema object from a schema message. Format is from serialize().
*
- * @param reader the reader interface to deserialize from
- * @return the deserialized object
- * @throws IOException if something went wrong
+ * @param schemaMessage a Message of type MessageHeader.Schema
+ * @return the deserialized Arrow Schema
*/
- public static Schema deserializeSchema(MessageReader reader) throws IOException {
- Message message = reader.readNextMessage();
- if (message == null) {
- throw new IOException("Unexpected end of input. Missing schema.");
- }
- if (message.headerType() != MessageHeader.Schema) {
- throw new IOException("Expected schema but header was " + message.headerType());
- }
-
+ public static Schema deserializeSchema(Message schemaMessage) {
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
- message.header(new org.apache.arrow.flatbuf.Schema()));
+ schemaMessage.header(new org.apache.arrow.flatbuf.Schema()));
}
/**
- * Deserializes a schema object. Format is from serialize().
+ * Deserializes an Arrow Schema read from the input channel. Format is from serialize().
*
* @param in the channel to deserialize from
- * @return the deserialized object
+ * @return the deserialized Arrow Schema
* @throws IOException if something went wrong
*/
public static Schema deserializeSchema(ReadChannel in) throws IOException {
- return deserializeSchema(new MessageChannelReader(in));
+ MessageChannelResult result = readMessage(in);
+ if (!result.hasMessage()) {
+ throw new IOException("Unexpected end of input when reading Schema");
+ }
+ if (result.getMessage().headerType() != MessageHeader.Schema) {
+ throw new IOException("Expected schema but header was " + result.getMessage().headerType());
+ }
+
+ return deserializeSchema(result.getMessage());
}
/**
@@ -165,6 +201,14 @@ public class MessageSerializer {
return new ArrowBlock(start, metadataLength + 4, bufferLength);
}
+ /**
+ * Write the Arrow buffers of the record batch to the output channel.
+ *
+ * @param out the output channel to write the buffers to
+ * @param batch an ArrowRecordBatch containing buffers to be written
+ * @return the number of bytes written
+ * @throws IOException
+ */
public static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) throws IOException {
long bufferStart = out.getCurrentPosition();
List<ArrowBuf> buffers = batch.getBuffers();
@@ -188,31 +232,49 @@ public class MessageSerializer {
}
/**
- * Deserializes a RecordBatch.
+ * Deserializes an ArrowRecordBatch from a record batch message and data in an ArrowBuf.
*
- * @param reader the reader interface to deserialize from
- * @param message the object to deserialize to
- * @param alloc to allocate buffers
- * @return the deserialized object
+ * @param recordBatchMessage a Message of type MessageHeader.RecordBatch
+ * @param bodyBuffer Arrow buffer containing the RecordBatch data
+ * @return the deserialized ArrowRecordBatch
* @throws IOException if something went wrong
*/
- public static ArrowRecordBatch deserializeRecordBatch(MessageReader reader, Message message, BufferAllocator alloc)
+ public static ArrowRecordBatch deserializeRecordBatch(Message recordBatchMessage, ArrowBuf bodyBuffer)
throws IOException {
- RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());
+ // headerType() is not checked here; the message header is read directly as a RecordBatch table.
+ RecordBatch recordBatchFB = (RecordBatch) recordBatchMessage.header(new RecordBatch());
+ return deserializeRecordBatch(recordBatchFB, bodyBuffer);
+ }
- // Now read the record batch body
- ArrowBuf buffer = reader.readMessageBody(message, alloc);
- return deserializeRecordBatch(recordBatchFB, buffer);
+ /**
+ * Deserializes an ArrowRecordBatch read from the input channel. This uses the given allocator
+ * to create an ArrowBuf for the batch body data.
+ *
+ * @param in Channel to read a RecordBatch message and data from
+ * @param allocator BufferAllocator to allocate an Arrow buffer to read message body data
+ * @return the deserialized ArrowRecordBatch
+ * @throws IOException if the end of the stream is reached, the message is not a RecordBatch,
+ * the body is larger than 2GB, or the body cannot be fully read
+ */
+ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, BufferAllocator allocator)
+ throws IOException {
+ MessageChannelResult result = readMessage(in);
+ if (!result.hasMessage()) {
+ throw new IOException("Unexpected end of input when reading a RecordBatch");
+ }
+ Message message = result.getMessage();
+ if (message.headerType() != MessageHeader.RecordBatch) {
+ throw new IOException("Expected RecordBatch but header was " + message.headerType());
+ }
+ // The body is read into a single int-sized ArrowBuf; reject lengths that the cast below
+ // would silently truncate (same limit as deserializeMessageBatch).
+ long bodyLength = result.getMessageBodyLength();
+ if (bodyLength > Integer.MAX_VALUE) {
+ throw new IOException("Cannot currently deserialize record batches over 2GB");
+ }
+ ArrowBuf bodyBuffer = readMessageBody(in, (int) bodyLength, allocator);
+ return deserializeRecordBatch(message, bodyBuffer);
}
/**
- * Deserializes a RecordBatch knowing the size of the entire message up front. This
+ * Deserializes an ArrowRecordBatch knowing the size of the entire message up front. This
* minimizes the number of reads to the underlying stream.
*
* @param in the channel to deserialize from
* @param block the object to deserialize to
* @param alloc to allocate buffers
- * @return the deserialized object
+ * @return the deserialized ArrowRecordBatch
* @throws IOException if something went wrong
*/
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block,
@@ -243,7 +305,7 @@ public class MessageSerializer {
}
/**
- * Deserializes a record batch given the Flatbuffer metadata and in-memory body.
+ * Deserializes an ArrowRecordBatch given the Flatbuffer metadata and in-memory body.
*
* @param recordBatchFB Deserialized FlatBuffer record batch
* @param body Read body of the record batch
@@ -320,33 +382,50 @@ public class MessageSerializer {
}
/**
- * Deserializes a DictionaryBatch.
+ * Deserializes an ArrowDictionaryBatch from a dictionary batch Message and data in an ArrowBuf.
*
- * @param reader where to read from
- * @param message the message message metadata to deserialize
- * @param alloc the allocator for new buffers
- * @return the corresponding dictionary batch
+ * @param message a message of type MessageHeader.DictionaryBatch
+ * @param bodyBuffer Arrow buffer containing the DictionaryBatch data
+ * @return the deserialized ArrowDictionaryBatch
* @throws IOException if something went wrong
*/
- public static ArrowDictionaryBatch deserializeDictionaryBatch(MessageReader reader,
- Message message,
- BufferAllocator alloc) throws IOException {
+ public static ArrowDictionaryBatch deserializeDictionaryBatch(Message message, ArrowBuf bodyBuffer)
+ throws IOException {
DictionaryBatch dictionaryBatchFB = (DictionaryBatch) message.header(new DictionaryBatch());
-
- // Now read the record batch body
- ArrowBuf body = reader.readMessageBody(message, alloc);
- ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body);
+ // The dictionary values are carried as a nested RecordBatch (dictionaryBatchFB.data()).
+ ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), bodyBuffer);
return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch);
}
/**
+ * Deserializes an ArrowDictionaryBatch read from the input channel. This uses the given allocator
+ * to create an ArrowBuf for the batch body data.
+ *
+ * @param in Channel to read a DictionaryBatch message and data from
+ * @param allocator BufferAllocator to allocate an Arrow buffer to read message body data
+ * @return the deserialized ArrowDictionaryBatch
+ * @throws IOException if the end of the stream is reached, the message is not a DictionaryBatch,
+ * the body is larger than 2GB, or the body cannot be fully read
+ */
+ public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in, BufferAllocator allocator)
+ throws IOException {
+ MessageChannelResult result = readMessage(in);
+ if (!result.hasMessage()) {
+ throw new IOException("Unexpected end of input when reading a DictionaryBatch");
+ }
+ Message message = result.getMessage();
+ if (message.headerType() != MessageHeader.DictionaryBatch) {
+ throw new IOException("Expected DictionaryBatch but header was " + message.headerType());
+ }
+ // Reject body lengths that the int cast below would silently truncate.
+ long bodyLength = result.getMessageBodyLength();
+ if (bodyLength > Integer.MAX_VALUE) {
+ throw new IOException("Cannot currently deserialize dictionary batches over 2GB");
+ }
+ ArrowBuf bodyBuffer = readMessageBody(in, (int) bodyLength, allocator);
+ return deserializeDictionaryBatch(message, bodyBuffer);
+ }
+
+ /**
* Deserializes a DictionaryBatch knowing the size of the entire message up front. This
* minimizes the number of reads to the underlying stream.
*
* @param in where to read from
* @param block block metadata for deserializing
* @param alloc to allocate new buffers
- * @return the corresponding dictionary
+ * @return the deserialized ArrowDictionaryBatch
* @throws IOException if something went wrong
*/
public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in,
@@ -381,30 +460,29 @@ public class MessageSerializer {
/**
* Deserialize a message that is either an ArrowDictionaryBatch or ArrowRecordBatch.
*
- * @param reader Interface to read messages from
- * @param alloc Allocator for message data
+ * @param reader MessageChannelReader to read a sequence of messages from a ReadChannel
- * @return The deserialized record batch
- * @throws IOException if the message is not an ArrowDictionaryBatch or ArrowRecordBatch
+ * @return the deserialized ArrowRecordBatch or ArrowDictionaryBatch, or null if the reader has
+ * no further message
+ * @throws IOException if the message is not an ArrowDictionaryBatch or ArrowRecordBatch, its
+ * body is larger than 2GB, or its metadata version is not V4
*/
- public static ArrowMessage deserializeMessageBatch(MessageReader reader, BufferAllocator alloc) throws IOException {
- Message message = reader.readNextMessage();
- if (message == null) {
+ public static ArrowMessage deserializeMessageBatch(MessageChannelReader reader) throws IOException {
+ MessageHolder holder = new MessageHolder();
+ if (!reader.readNext(holder)) {
return null;
- } else if (message.bodyLength() > Integer.MAX_VALUE) {
+ } else if (holder.message.bodyLength() > Integer.MAX_VALUE) {
+ // The body has already been read by readNext; release it before failing.
+ releaseBodyBuffer(holder);
throw new IOException("Cannot currently deserialize record batches over 2GB");
}
- if (message.version() != MetadataVersion.V4) {
+ if (holder.message.version() != MetadataVersion.V4) {
+ releaseBodyBuffer(holder);
throw new IOException("Received metadata with an incompatible version number");
}
- switch (message.headerType()) {
+ // NOTE(review): if readNext leaves bodyBuffer null for zero-length bodies, confirm the
+ // deserializers below accept a null buffer.
+ switch (holder.message.headerType()) {
case MessageHeader.RecordBatch:
- return deserializeRecordBatch(reader, message, alloc);
+ return deserializeRecordBatch(holder.message, holder.bodyBuffer);
case MessageHeader.DictionaryBatch:
- return deserializeDictionaryBatch(reader, message, alloc);
+ return deserializeDictionaryBatch(holder.message, holder.bodyBuffer);
default:
- throw new IOException("Unexpected message header type " + message.headerType());
+ releaseBodyBuffer(holder);
+ throw new IOException("Unexpected message header type " + holder.message.headerType());
}
}
+
+ /**
+ * Releases the body buffer read into {@code holder}, if any. Used on the error paths of
+ * deserializeMessageBatch, where the buffer would otherwise never be handed to a batch.
+ *
+ * @param holder holder whose bodyBuffer (possibly null) should be released
+ */
+ private static void releaseBodyBuffer(MessageHolder holder) {
+ if (holder.bodyBuffer != null) {
+ holder.bodyBuffer.release();
+ }
+ }
@@ -417,7 +495,7 @@ public class MessageSerializer {
* @throws IOException if the message is not an ArrowDictionaryBatch or ArrowRecordBatch
*/
public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocator alloc) throws IOException {
- return deserializeMessageBatch(new MessageChannelReader(in), alloc);
+ // Wraps the channel in a fresh MessageChannelReader on each call; alloc is handed to that reader.
+ return deserializeMessageBatch(new MessageChannelReader(in, alloc));
}
/**
@@ -440,24 +518,58 @@ public class MessageSerializer {
return builder.dataBuffer();
}
- private static Message deserializeMessage(ReadChannel in) throws IOException {
+ /**
+ * Read a Message from the in channel and return a MessageChannelResult object that contains the
+ * Message, raw buffer containing the read Message, and length of the Message in bytes. If
+ * the end-of-stream has been reached, MessageChannelResult.hasMessage() will return false.
+ * A length prefix of 0, or fewer than 4 bytes available for the prefix, is treated as
+ * end-of-stream.
+ *
+ * @param in ReadChannel to read messages from
+ * @return MessageChannelResult with Message and message information
+ * @throws IOException if the length prefix was read but the message bytes are truncated
+ */
+ public static MessageChannelResult readMessage(ReadChannel in) throws IOException {
+ int messageLength = 0;
+ ByteBuffer messageBuffer = null;
+ Message message = null;
+
// Read the message size. There is an i32 little endian prefix.
ByteBuffer buffer = ByteBuffer.allocate(4);
- if (in.readFully(buffer) != 4) {
- return null;
- }
- int messageLength = bytesToInt(buffer.array());
- if (messageLength == 0) {
- return null;
+ // A partial or missing prefix is treated as end-of-stream (no message is returned).
+ if (in.readFully(buffer) == 4) {
+ messageLength = MessageSerializer.bytesToInt(buffer.array());
+
+ // Length of 0 indicates end of stream
+ if (messageLength != 0) {
+
+ // Read the message into the buffer.
+ messageBuffer = ByteBuffer.allocate(messageLength);
+ if (in.readFully(messageBuffer) != messageLength) {
+ throw new IOException(
+ "Unexpected end of stream trying to read message.");
+ }
+ messageBuffer.rewind();
+
+ // Load the message.
+ message = Message.getRootAsMessage(messageBuffer);
+ }
}
- buffer = ByteBuffer.allocate(messageLength);
- if (in.readFully(buffer) != messageLength) {
- throw new IOException(
- "Unexpected end of stream trying to read message.");
- }
- buffer.rewind();
+ return new MessageChannelResult(messageLength, messageBuffer, message);
+ }
- return Message.getRootAsMessage(buffer);
+ /**
+ * Read a Message body from the in channel into an ArrowBuf.
+ *
+ * @param in ReadChannel to read message body from
+ * @param bodyLength Length in bytes of the message body to read
+ * @param allocator Allocate the ArrowBuf to contain message body data
+ * @return an ArrowBuf containing the message body data; it is not released by this method
+ * @throws IOException if fewer than bodyLength bytes could be read, or reading fails; in
+ * either case the allocated buffer is released before the exception propagates
+ */
+ public static ArrowBuf readMessageBody(ReadChannel in, int bodyLength, BufferAllocator allocator) throws IOException {
+ ArrowBuf bodyBuffer = allocator.buffer(bodyLength);
+ boolean success = false;
+ try {
+ if (in.readFully(bodyBuffer, bodyLength) != bodyLength) {
+ throw new IOException("Unexpected end of input trying to read batch.");
+ }
+ success = true;
+ return bodyBuffer;
+ } finally {
+ // Release on any failure so a truncated or broken stream does not leak allocator memory.
+ if (!success) {
+ bodyBuffer.release();
+ }
+ }
}
}
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java
index 064f500..f677c3d 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java
@@ -26,6 +26,8 @@ import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.channels.Channels;
import java.util.Collections;
import java.util.List;
@@ -61,6 +63,57 @@ public class MessageSerializerTest {
return bytes;
}
+ // Encodes v into bytes with intToBytes, then decodes the same array with bytesToInt.
+ private int intToByteRoundtrip(int v, byte[] bytes) {
+ MessageSerializer.intToBytes(v, bytes);
+ return MessageSerializer.bytesToInt(bytes);
+ }
+
+ @Test
+ public void testIntToBytes() {
+ byte[] bytes = new byte[4];
+ // Note: 1 << 32 == 1 in Java (int shift distances are taken mod 32), so use 1 << 24 to
+ // exercise the most significant byte; zero and negative values cover the sign bit.
+ int[] values = new int[] {0, 1, 15, 1 << 8, 1 << 16, 1 << 24, -1, Integer.MIN_VALUE, Integer.MAX_VALUE};
+ for (int v: values) {
+ assertEquals(v, intToByteRoundtrip(v, bytes));
+ }
+ }
+
+ @Test
+ public void testWriteMessageBufferAligned() throws IOException {
+ ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ WriteChannel out = new WriteChannel(Channels.newChannel(outputStream));
+
+ // This is not a valid Arrow Message, only to test writing and alignment
+ ByteBuffer buffer = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN);
+ buffer.putInt(1);
+ buffer.putInt(2);
+ buffer.flip();
+
+ // Length prefix (4) + 8 message bytes = 12, padded to 16
+ int bytesWritten = MessageSerializer.writeMessageBufferAligned(out, 8, buffer);
+ assertEquals(16, bytesWritten);
+
+ buffer.rewind();
+ buffer.putInt(3);
+ buffer.flip();
+ // Length prefix (4) + 4 message bytes = 8, no padding needed
+ bytesWritten = MessageSerializer.writeMessageBufferAligned(out, 4, buffer);
+ assertEquals(8, bytesWritten);
+
+ ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray());
+ ReadChannel in = new ReadChannel(Channels.newChannel(inputStream));
+ ByteBuffer result = ByteBuffer.allocate(32).order(ByteOrder.LITTLE_ENDIAN);
+ // Exactly 16 + 8 bytes must have been written; the oversized buffer would expose extra output.
+ assertEquals(24, in.readFully(result));
+ result.rewind();
+
+ // First message size, 2 int values, 4 bytes of zero padding
+ assertEquals(12, result.getInt());
+ assertEquals(1, result.getInt());
+ assertEquals(2, result.getInt());
+ assertEquals(0, result.getInt());
+
+ // Second message size and 1 int value
+ assertEquals(4, result.getInt());
+ assertEquals(3, result.getInt());
+ }
+
@Test
public void testSchemaMessageSerialization() throws IOException {
Schema schema = testSchema();
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
index 431ebf0..f1f5c00 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
@@ -60,6 +60,32 @@ public class TestArrowStream extends BaseFileTest {
}
@Test
+ public void testStreamZeroLengthBatch() throws IOException {
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+
+ try (IntVector vector = new IntVector("foo", allocator)) {
+ Schema schema = new Schema(Collections.singletonList(vector.getField()), null);
+ try (VectorSchemaRoot root = new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount());
+ ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os))) {
+ // Write a single batch with zero rows, followed by the end-of-stream marker.
+ vector.setValueCount(0);
+ root.setRowCount(0);
+ writer.writeBatch();
+ writer.end();
+ }
+ }
+
+ ByteArrayInputStream in = new ByteArrayInputStream(os.toByteArray());
+
+ try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+ IntVector vector = (IntVector) root.getFieldVectors().get(0);
+ // The zero-row batch must be loaded as a batch, not mistaken for end-of-stream;
+ // otherwise the zero-count assertions below would pass trivially.
+ org.junit.Assert.assertTrue(reader.loadNextBatch());
+ assertEquals(0, vector.getValueCount());
+ assertEquals(0, root.getRowCount());
+ }
+ }
+
+ @Test
public void testReadWrite() throws IOException {
Schema schema = MessageSerializerTest.testSchema();
try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {