This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch ARROW-6313-flatbuffer-alignment in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 0352456387e0d02036029de4fbb6d49324eb779e Author: tianchen <niki...@alibaba-inc.com> AuthorDate: Fri Sep 6 20:36:38 2019 -0700 ARROW-6315: [Java] Make change to ensure flatbuffer reads are aligned Implements the IPC message format alignment changes for [ARROW-6315](https://issues.apache.org/jira/browse/ARROW-6315). i. MessageReader can still read messages written with the old alignment. ii. ArrowWriter can optionally produce messages in either the new aligned format or the old (legacy) format. Closes #5229 from tianchen92/ARROW-align-java and squashes the following commits: 1eb71d27c <Bryan Cutler> ARROW-6461: Prevent EchoServer from closing the client socket after writing cd4fd050e <tianchen> fix small bugs 9a690e47d <tianchen> fix comments and styles 5ee858c56 <tianchen> Make change to ensure flatbuffer reads are aligned Lead-authored-by: tianchen <niki...@alibaba-inc.com> Co-authored-by: Bryan Cutler <cutl...@gmail.com> Signed-off-by: Micah Kornfield <emkornfi...@gmail.com> --- .../org/apache/arrow/tools/EchoServerTest.java | 2 +- .../apache/arrow/vector/ipc/ArrowFileWriter.java | 13 ++- .../apache/arrow/vector/ipc/ArrowStreamWriter.java | 22 ++++- .../org/apache/arrow/vector/ipc/ArrowWriter.java | 17 +++- .../org/apache/arrow/vector/ipc/WriteChannel.java | 9 ++ .../apache/arrow/vector/ipc/message/IpcOption.java | 28 ++++++ .../vector/ipc/message/MessageSerializer.java | 108 ++++++++++++++++++--- .../arrow/vector/ipc/MessageSerializerTest.java | 14 +-- .../arrow/vector/ipc/TestArrowReaderWriter.java | 52 +++++++++- .../apache/arrow/vector/ipc/TestArrowStream.java | 2 +- .../arrow/vector/ipc/TestArrowStreamPipe.java | 2 +- 11 files changed, 231 insertions(+), 38 deletions(-) diff --git a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java index 219926a..bfb136c 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java +++ 
b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java @@ -127,7 +127,7 @@ public class EchoServerTest { } Assert.assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); - assertEquals(reader.bytesRead(), writer.bytesWritten()); + assertEquals(reader.bytesRead() + 4, writer.bytesWritten()); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java index 395a617..936ab6d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java @@ -29,6 +29,7 @@ import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowFooter; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,6 +48,11 @@ public class ArrowFileWriter extends ArrowWriter { super(root, provider, out); } + public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, + IpcOption option) { + super(root, provider, out, option); + } + @Override protected void startInternal(WriteChannel out) throws IOException { ArrowMagic.writeMagic(out, true); @@ -68,7 +74,12 @@ public class ArrowFileWriter extends ArrowWriter { @Override protected void endInternal(WriteChannel out) throws IOException { - out.writeIntLittleEndian(0); + if (option.write_legacy_ipc_format) { + out.writeIntLittleEndian(0); + } else { + out.writeLongLittleEndian(0); + } + long footerStart = out.getCurrentPosition(); out.write(new ArrowFooter(schema, dictionaryBlocks, recordBlocks), false); int footerLength = (int) (out.getCurrentPosition() - footerStart); diff --git 
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java index ec0f42e..e74323b 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java @@ -24,6 +24,7 @@ import java.nio.channels.WritableByteChannel; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.IpcOption; /** * Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel. @@ -44,14 +45,23 @@ public class ArrowStreamWriter extends ArrowWriter { /** * Construct an ArrowStreamWriter with an optional DictionaryProvider for the WritableByteChannel. + */ + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + this(root, provider, out, new IpcOption()); + } + + /** + * Construct an ArrowStreamWriter with an optional DictionaryProvider for the WritableByteChannel. * * @param root Existing VectorSchemaRoot with vectors to be written. * @param provider DictionaryProvider for any vectors that are dictionary encoded. * (Optional, can be null) + * @param option IPC write options * @param out WritableByteChannel for writing. */ - public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { - super(root, provider, out); + public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, + IpcOption option) { + super(root, provider, out, option); } /** @@ -60,8 +70,12 @@ public class ArrowStreamWriter extends ArrowWriter { * @param out Open WriteChannel with an active Arrow stream. 
* @throws IOException on error */ - public static void writeEndOfStream(WriteChannel out) throws IOException { - out.writeIntLittleEndian(0); + public void writeEndOfStream(WriteChannel out) throws IOException { + if (option.write_legacy_ipc_format) { + out.writeIntLittleEndian(0); + } else { + out.writeLongLittleEndian(0); + } } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java index 6366f2f..52ab3de 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java @@ -33,6 +33,7 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; @@ -59,16 +60,24 @@ public abstract class ArrowWriter implements AutoCloseable { private boolean dictWritten = false; + protected IpcOption option; + + protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + this (root, provider, out, new IpcOption()); + } + /** * Note: fields are not closed when the writer is closed. 
* * @param root the vectors to write to the output * @param provider where to find the dictionaries * @param out the output where to write + * @param option IPC write options */ - protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option) { this.unloader = new VectorUnloader(root); this.out = new WriteChannel(out); + this.option = option; List<Field> fields = new ArrayList<>(root.getSchema().getFields().size()); Set<Long> dictionaryIdsUsed = new HashSet<>(); @@ -112,14 +121,14 @@ public abstract class ArrowWriter implements AutoCloseable { } protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException { - ArrowBlock block = MessageSerializer.serialize(out, batch); + ArrowBlock block = MessageSerializer.serialize(out, batch, option); LOGGER.debug("DictionaryRecordBatch at {}, metadata: {}, body: {}", block.getOffset(), block.getMetadataLength(), block.getBodyLength()); return block; } protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException { - ArrowBlock block = MessageSerializer.serialize(out, batch); + ArrowBlock block = MessageSerializer.serialize(out, batch, option); LOGGER.debug("RecordBatch at {}, metadata: {}, body: {}", block.getOffset(), block.getMetadataLength(), block.getBodyLength()); return block; @@ -140,7 +149,7 @@ public abstract class ArrowWriter implements AutoCloseable { startInternal(out); // write the schema - for file formats this is duplicated in the footer, but matches // the streaming format - MessageSerializer.serialize(out, schema); + MessageSerializer.serialize(out, schema, option); } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java index eef36d3..2d36c93 100644 --- 
a/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/WriteChannel.java @@ -102,6 +102,15 @@ public class WriteChannel implements AutoCloseable { } /** + * Writes <code>v</code> in little-endian format to the underlying channel. + */ + public long writeLongLittleEndian(long v) throws IOException { + byte[] outBuffer = new byte[8]; + MessageSerializer.longToBytes(v, outBuffer); + return write(outBuffer); + } + + /** * Writes the buffer to the underlying channel. */ public void write(ArrowBuf buffer) throws IOException { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java new file mode 100644 index 0000000..81a0603 --- /dev/null +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/IpcOption.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.ipc.message; + +/** + * IPC options, now only use for write. 
+ */ +public class IpcOption { + + // Write the pre-0.15.0 encapsulated IPC message format + // consisting of a 4-byte prefix instead of 8 byte + public boolean write_legacy_ipc_format = false; +} diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java index 4016802..34ea077 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/message/MessageSerializer.java @@ -56,6 +56,9 @@ import io.netty.buffer.ArrowBuf; */ public class MessageSerializer { + // This 0xFFFFFFFF value is the first 4 bytes of a valid IPC message + public static final int IPC_CONTINUATION_TOKEN = -1; + /** * Convert an array of 4 bytes to a little endian i32 value. * @@ -83,6 +86,28 @@ public class MessageSerializer { } /** + * Convert a long to a 8 byte array. + * + * @param value long value input + * @param bytes existing byte array with minimum length of 8 to contain the conversion output + */ + public static void longToBytes(long value, byte[] bytes) { + bytes[7] = (byte) (value >>> 56); + bytes[6] = (byte) (value >>> 48); + bytes[5] = (byte) (value >>> 40); + bytes[4] = (byte) (value >>> 32); + bytes[3] = (byte) (value >>> 24); + bytes[2] = (byte) (value >>> 16); + bytes[1] = (byte) (value >>> 8); + bytes[0] = (byte) (value); + } + + public static int writeMessageBuffer(WriteChannel out, int messageLength, ByteBuffer messageBuffer) + throws IOException { + return writeMessageBuffer(out, messageLength, messageBuffer, new IpcOption()); + } + + /** * Write the serialized Message metadata, prefixed by the length, to the output Channel. This * ensures that it aligns to an 8 byte boundary and will adjust the message length to include * any padding used for alignment. 
@@ -91,22 +116,36 @@ public class MessageSerializer { * @param messageLength Number of bytes in the message buffer, written as little Endian prefix * @param messageBuffer Message metadata buffer to be written, this does not include any * message body data which should be subsequently written to the Channel + * @param option IPC write options * @return Number of bytes written * @throws IOException on error */ - public static int writeMessageBuffer(WriteChannel out, int messageLength, ByteBuffer messageBuffer) + public static int writeMessageBuffer(WriteChannel out, int messageLength, ByteBuffer messageBuffer, IpcOption option) throws IOException { - // ensure that message aligns to 8 byte padding - 4 bytes for size, then message body - if ((messageLength + 4) % 8 != 0) { - messageLength += 8 - (messageLength + 4) % 8; + // if write the pre-0.15.0 encapsulated IPC message format consisting of a 4-byte prefix instead of 8 byte + int prefixSize = option.write_legacy_ipc_format ? 4 : 8; + + // ensure that message aligns to 8 byte padding - prefix_size bytes, then message body + if ((messageLength + prefixSize ) % 8 != 0) { + messageLength += 8 - (messageLength + prefixSize) % 8; + } + if (!option.write_legacy_ipc_format) { + out.writeIntLittleEndian(IPC_CONTINUATION_TOKEN); } out.writeIntLittleEndian(messageLength); out.write(messageBuffer); out.align(); // any bytes written are already captured by our size modification above - return messageLength + 4; + return messageLength + prefixSize; + } + + /** + * Serialize a schema object. 
+ */ + public static long serialize(WriteChannel out, Schema schema) throws IOException { + return serialize(out, schema, new IpcOption()); } /** @@ -117,7 +156,7 @@ public class MessageSerializer { * @return the number of bytes written * @throws IOException if something went wrong */ - public static long serialize(WriteChannel out, Schema schema) throws IOException { + public static long serialize(WriteChannel out, Schema schema, IpcOption option) throws IOException { long start = out.getCurrentPosition(); assert start % 8 == 0; @@ -125,7 +164,7 @@ public class MessageSerializer { int messageLength = serializedMessage.remaining(); - int bytesWritten = writeMessageBuffer(out, messageLength, serializedMessage); + int bytesWritten = writeMessageBuffer(out, messageLength, serializedMessage, option); assert bytesWritten % 8 == 0; return bytesWritten; } @@ -182,13 +221,20 @@ public class MessageSerializer { /** * Serializes an ArrowRecordBatch. Returns the offset and length of the written batch. + */ + public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) throws IOException { + return serialize(out, batch, new IpcOption()); + } + + /** + * Serializes an ArrowRecordBatch. Returns the offset and length of the written batch. 
* * @param out where to write the batch * @param batch the object to serialize to out * @return the serialized block metadata * @throws IOException if something went wrong */ - public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch) throws IOException { + public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch, IpcOption option) throws IOException { long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); @@ -198,8 +244,14 @@ public class MessageSerializer { int metadataLength = serializedMessage.remaining(); + int prefixSize = 4; + if (!option.write_legacy_ipc_format) { + out.writeIntLittleEndian(IPC_CONTINUATION_TOKEN); + prefixSize = 8; + } + // calculate alignment bytes so that metadata length points to the correct location after alignment - int padding = (int) ((start + metadataLength + 4) % 8); + int padding = (int) ((start + metadataLength + prefixSize) % 8); if (padding != 0) { metadataLength += (8 - padding); } @@ -214,7 +266,7 @@ public class MessageSerializer { assert bufferLength % 8 == 0; // Metadata size in the Block account for the size prefix - return new ArrowBlock(start, metadataLength + 4, bufferLength); + return new ArrowBlock(start, metadataLength + prefixSize, bufferLength); } /** @@ -305,7 +357,7 @@ public class MessageSerializer { */ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block, BufferAllocator alloc) throws IOException { - // Metadata length contains integer prefix plus byte padding + // Metadata length contains prefix_size bytes plus byte padding long totalLen = block.getMetadataLength() + block.getBodyLength(); if (totalLen > Integer.MAX_VALUE) { @@ -317,7 +369,9 @@ public class MessageSerializer { throw new IOException("Unexpected end of input trying to read batch."); } - ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); + int prefixSize = buffer.getInt(0) == IPC_CONTINUATION_TOKEN ? 
8 : 4; + + ArrowBuf metadataBuffer = buffer.slice(prefixSize, block.getMetadataLength() - prefixSize); Message messageFB = Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer()); @@ -375,15 +429,21 @@ public class MessageSerializer { return deserializeRecordBatch(serializedMessage.getMessage(), underlying); } + public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException { + return serialize(out, batch, new IpcOption()); + } + /** * Serializes a dictionary ArrowRecordBatch. Returns the offset and length of the written batch. * * @param out where to serialize * @param batch the batch to serialize + * @param option options for IPC * @return the metadata of the serialized block * @throws IOException if something went wrong */ - public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException { + public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch, IpcOption option) + throws IOException { long start = out.getCurrentPosition(); int bodyLength = batch.computeBodyLength(); assert bodyLength % 8 == 0; @@ -392,8 +452,14 @@ public class MessageSerializer { int metadataLength = serializedMessage.remaining(); + int prefixSize = 4; + if (!option.write_legacy_ipc_format) { + out.writeIntLittleEndian(IPC_CONTINUATION_TOKEN); + prefixSize = 8; + } + // calculate alignment bytes so that metadata length points to the correct location after alignment - int padding = (int) ((start + metadataLength + 4) % 8); + int padding = (int) ((start + metadataLength + prefixSize) % 8); if (padding != 0) { metadataLength += (8 - padding); } @@ -409,7 +475,7 @@ public class MessageSerializer { assert bufferLength % 8 == 0; // Metadata size in the Block account for the size prefix - return new ArrowBlock(start, metadataLength + 4, bufferLength); + return new ArrowBlock(start, metadataLength + prefixSize, bufferLength); } /** @@ -491,7 +557,9 @@ public class MessageSerializer { throw 
new IOException("Unexpected end of input trying to read batch."); } - ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4); + int prefixSize = buffer.getInt(0) == IPC_CONTINUATION_TOKEN ? 8 : 4; + + ArrowBuf metadataBuffer = buffer.slice(prefixSize, block.getMetadataLength() - prefixSize); Message messageFB = Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer()); @@ -584,7 +652,15 @@ public class MessageSerializer { // Read the message size. There is an i32 little endian prefix. ByteBuffer buffer = ByteBuffer.allocate(4); if (in.readFully(buffer) == 4) { + int messageLength = MessageSerializer.bytesToInt(buffer.array()); + if (messageLength == IPC_CONTINUATION_TOKEN) { + buffer.clear(); + // ARROW-6313, if the first 4 bytes are continuation message, read the next 4 for the length + if (in.readFully(buffer) == 4) { + messageLength = MessageSerializer.bytesToInt(buffer.array()); + } + } // Length of 0 indicates end of stream if (messageLength != 0) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java index 789da1f..1cbd5bb 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/MessageSerializerTest.java @@ -95,7 +95,7 @@ public class MessageSerializerTest { buffer.putInt(3); buffer.flip(); bytesWritten = MessageSerializer.writeMessageBuffer(out, 4, buffer); - assertEquals(8, bytesWritten); + assertEquals(16, bytesWritten); ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); ReadChannel in = new ReadChannel(Channels.newChannel(inputStream)); @@ -103,15 +103,17 @@ public class MessageSerializerTest { in.readFully(result); result.rewind(); - // First message size, 2 int values, 4 bytes of zero padding - assertEquals(12, result.getInt()); + // First message continuation, size, 
and 2 int values + assertEquals(MessageSerializer.IPC_CONTINUATION_TOKEN, result.getInt()); + assertEquals(8, result.getInt()); assertEquals(1, result.getInt()); assertEquals(2, result.getInt()); - assertEquals(0, result.getInt()); - // Second message size and 1 int value - assertEquals(4, result.getInt()); + // Second message continuation, size, 1 int value and 4 bytes padding + assertEquals(MessageSerializer.IPC_CONTINUATION_TOKEN, result.getInt()); + assertEquals(8, result.getInt()); assertEquals(3, result.getInt()); + assertEquals(0, result.getInt()); } @Test diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java index 5d1f792..58ad669 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java @@ -25,6 +25,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.ByteBuffer; @@ -42,6 +43,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Collections2; import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.TestUtils; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorLoader; @@ -54,6 +56,7 @@ import org.apache.arrow.vector.ipc.message.ArrowBlock; import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; import org.apache.arrow.vector.ipc.message.MessageSerializer; import 
org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; @@ -75,7 +78,7 @@ public class TestArrowReaderWriter { private Dictionary dictionary2; private Schema schema; - private Schema encodedchema; + private Schema encodedSchema; @Before public void init() { @@ -161,7 +164,8 @@ public class TestArrowReaderWriter { // deserialize the buffer. ByteBuffer headerBuffer = ByteBuffer.allocate(recordBatches.get(0).getMetadataLength()); headerBuffer.put(byteArray, (int) recordBatches.get(0).getOffset(), headerBuffer.capacity()); - headerBuffer.position(4); + // new format prefix_size ==8 + headerBuffer.position(8); Message messageFB = Message.getRootAsMessage(headerBuffer); RecordBatch recordBatchFB = (RecordBatch) messageFB.header(new RecordBatch()); assertEquals(2, recordBatchFB.buffersLength()); @@ -335,7 +339,7 @@ public class TestArrowReaderWriter { try (ArrowStreamReader reader = new ArrowStreamReader( new ByteArrayReadableSeekableByteChannel(outStream.toByteArray()), allocator)) { Schema readSchema = reader.getVectorSchemaRoot().getSchema(); - assertEquals(encodedchema, readSchema); + assertEquals(encodedSchema, readSchema); assertEquals(2, reader.getDictionaryVectors().size()); assertTrue(reader.loadNextBatch()); assertTrue(reader.loadNextBatch()); @@ -401,8 +405,48 @@ public class TestArrowReaderWriter { schemaFields.add(DictionaryUtility.toMessageFormat(encodedVectorA2.getField(), provider, new HashSet<>())); schema = new Schema(schemaFields); - encodedchema = new Schema(Arrays.asList(encodedVectorA1.getField(), encodedVectorA2.getField())); + encodedSchema = new Schema(Arrays.asList(encodedVectorA1.getField(), encodedVectorA2.getField())); return batches; } + + @Test + public void testLegacyIpcBackwardsCompatibility() throws Exception { + Schema schema = new Schema(asList(Field.nullable("field", new ArrowType.Int(32, true)))); + IntVector vector = new IntVector("vector", allocator); + final int valueCount = 2; + 
vector.setValueCount(valueCount); + vector.setSafe(0, 1); + vector.setSafe(1, 2); + ArrowRecordBatch batch = new ArrowRecordBatch(valueCount, asList(new ArrowFieldNode(valueCount, 0)), + asList(vector.getValidityBuffer(), vector.getDataBuffer())); + + ByteArrayOutputStream outStream = new ByteArrayOutputStream(); + WriteChannel out = new WriteChannel(newChannel(outStream)); + + // write legacy ipc format + IpcOption option = new IpcOption(); + option.write_legacy_ipc_format = true; + MessageSerializer.serialize(out, schema, option); + MessageSerializer.serialize(out, batch); + + ReadChannel in = new ReadChannel(newChannel(new ByteArrayInputStream(outStream.toByteArray()))); + Schema readSchema = MessageSerializer.deserializeSchema(in); + assertEquals(schema, readSchema); + ArrowRecordBatch readBatch = MessageSerializer.deserializeRecordBatch(in, allocator); + assertEquals(batch.getLength(), readBatch.getLength()); + assertEquals(batch.computeBodyLength(), readBatch.computeBodyLength()); + + // write ipc format with continuation + option.write_legacy_ipc_format = false; + MessageSerializer.serialize(out, schema, option); + MessageSerializer.serialize(out, batch); + + ReadChannel in2 = new ReadChannel(newChannel(new ByteArrayInputStream(outStream.toByteArray()))); + Schema readSchema2 = MessageSerializer.deserializeSchema(in2); + assertEquals(schema, readSchema2); + ArrowRecordBatch readBatch2 = MessageSerializer.deserializeRecordBatch(in2, allocator); + assertEquals(batch.getLength(), readBatch2.getLength()); + assertEquals(batch.computeBodyLength(), readBatch2.computeBodyLength()); + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java index 92e5276..5d8f5df 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java @@ -117,7 +117,7 @@ public 
class TestArrowStream extends BaseFileTest { assertTrue(reader.loadNextBatch()); } // TODO figure out why reader isn't getting padding bytes - assertEquals(bytesWritten, reader.bytesRead() + 4); + assertEquals(bytesWritten, reader.bytesRead() + 8); assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStreamPipe.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStreamPipe.java index 422a63f..07f4017 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStreamPipe.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStreamPipe.java @@ -156,6 +156,6 @@ public class TestArrowStreamPipe { writer.join(); assertEquals(NUM_BATCHES, reader.getBatchesRead()); - assertEquals(writer.bytesWritten(), reader.bytesRead()); + assertEquals(writer.bytesWritten(), reader.bytesRead() + 4); } }