Repository: arrow Updated Branches: refs/heads/master b4892fd9f -> 01114d831
ARROW-783: [Java/C++] Fixes for 0-length record batches @StevenMPhillips @nongli @julienledem I found a number of issues in both C++ and Java around the handling of 0-length vectors. It seems that preserving a single inconsequential offset for a length-0 variable length vector can be a bit difficult, so I relaxed a restriction in `loadFieldVectors` about this. Let me know if there's anything concerning about the other changes around EOS signaling. Author: Wes McKinney <[email protected]> Closes #505 from wesm/ARROW-783 and squashes the following commits: 28ddcab [Wes McKinney] * Have loadNextBatch return true/false for EOS to accommodate 0-length record batches * Relax n + 1 restriction for 0-length vectors Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/01114d83 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/01114d83 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/01114d83 Branch: refs/heads/master Commit: 01114d831b1cd0cdb9a7f28958d181dcece2537f Parents: b4892fd Author: Wes McKinney <[email protected]> Authored: Fri Apr 14 16:20:06 2017 -0400 Committer: Wes McKinney <[email protected]> Committed: Fri Apr 14 16:20:06 2017 -0400 ---------------------------------------------------------------------- cpp/src/arrow/loader.cc | 16 +++---------- integration/integration_test.py | 8 +++---- .../org/apache/arrow/tools/FileRoundtrip.java | 4 +--- .../org/apache/arrow/tools/FileToStream.java | 10 +++++--- .../org/apache/arrow/tools/Integration.java | 17 ++++++++----- .../org/apache/arrow/tools/StreamToFile.java | 10 +++++--- .../arrow/tools/ArrowFileTestFixtures.java | 4 +++- .../org/apache/arrow/tools/EchoServerTest.java | 4 ++-- .../codegen/templates/NullableValueVectors.java | 4 +++- .../arrow/vector/file/ArrowFileReader.java | 4 ++-- .../apache/arrow/vector/file/ArrowReader.java | 14 +++++++++-- .../arrow/vector/file/json/JsonFileReader.java | 4 +++- 
.../apache/arrow/vector/file/TestArrowFile.java | 25 ++++++++++---------- .../arrow/vector/file/TestArrowStream.java | 12 ++++++---- .../arrow/vector/file/TestArrowStreamPipe.java | 9 ++++--- 15 files changed, 82 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/cpp/src/arrow/loader.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/loader.cc b/cpp/src/arrow/loader.cc index f9f6e6f..e4e1ba4 100644 --- a/cpp/src/arrow/loader.cc +++ b/cpp/src/arrow/loader.cc @@ -97,13 +97,8 @@ class ArrayLoader { std::shared_ptr<Buffer> null_bitmap, offsets, values; RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); - if (field_meta.length > 0) { - RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); - RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &values)); - } else { - context_->buffer_index += 2; - offsets = values = nullptr; - } + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &values)); result_ = std::make_shared<CONTAINER>( field_meta.length, offsets, values, null_bitmap, field_meta.null_count); @@ -166,12 +161,7 @@ class ArrayLoader { std::shared_ptr<Buffer> null_bitmap, offsets; RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); - if (field_meta.length > 0) { - RETURN_NOT_OK(GetBuffer(context_->buffer_index, &offsets)); - } else { - offsets = nullptr; - } - ++context_->buffer_index; + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); const int num_children = type.num_children(); if (num_children != 1) { http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/integration/integration_test.py ---------------------------------------------------------------------- diff --git a/integration/integration_test.py b/integration/integration_test.py index 6631dc8..661f5c9 100644 --- a/integration/integration_test.py +++ 
b/integration/integration_test.py @@ -593,7 +593,7 @@ def _generate_file(fields, batch_sizes): return JSONFile(schema, batches) -def generate_primitive_case(): +def generate_primitive_case(batch_sizes): types = ['bool', 'int8', 'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 'float32', 'float64', 'binary', 'utf8'] @@ -604,7 +604,6 @@ def generate_primitive_case(): fields.append(get_field(type_ + "_nullable", type_, True)) fields.append(get_field(type_ + "_nonnullable", type_, False)) - batch_sizes = [7, 10] return _generate_file(fields, batch_sizes) @@ -648,9 +647,8 @@ def get_generated_json_files(): return file_objs = [ - generate_primitive_case(), - generate_primitive_case(), - generate_primitive_case(), + generate_primitive_case([7, 10]), + generate_primitive_case([0, 0, 0]), generate_datetime_case(), generate_nested_case() ] http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java ---------------------------------------------------------------------- diff --git a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java index b862192..135d492 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/FileRoundtrip.java @@ -93,9 +93,7 @@ public class FileRoundtrip { fileOutputStream.getChannel())) { arrowWriter.start(); while (true) { - arrowReader.loadNextBatch(); - int loaded = root.getRowCount(); - if (loaded == 0) { + if (!arrowReader.loadNextBatch()) { break; } else { arrowWriter.writeBatch(); http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java ---------------------------------------------------------------------- diff --git a/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java 
b/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java index be404fd..6722b30 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/FileToStream.java @@ -41,12 +41,16 @@ public class FileToStream { try (ArrowFileReader reader = new ArrowFileReader(in.getChannel(), allocator)) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); // load the first batch before instantiating the writer so that we have any dictionaries - reader.loadNextBatch(); + if (!reader.loadNextBatch()) { + throw new IOException("Unable to read first record batch"); + } try (ArrowStreamWriter writer = new ArrowStreamWriter(root, reader, out)) { writer.start(); - while (root.getRowCount() > 0) { + while (true) { writer.writeBatch(); - reader.loadNextBatch(); + if (!reader.loadNextBatch()) { + break; + } } writer.end(); } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/main/java/org/apache/arrow/tools/Integration.java ---------------------------------------------------------------------- diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java index 453693d..e8266d5 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java @@ -126,7 +126,9 @@ public class Integration { .pretty(true))) { writer.start(schema); for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - arrowReader.loadRecordBatch(rbBlock); + if (!arrowReader.loadRecordBatch(rbBlock)) { + throw new IOException("Expected to load record batch"); + } writer.write(root); } } @@ -148,10 +150,8 @@ public class Integration { ArrowFileWriter arrowWriter = new ArrowFileWriter(root, null, fileOutputStream .getChannel())) { arrowWriter.start(); - reader.read(root); - while (root.getRowCount() != 0) { + while (reader.read(root)) { 
arrowWriter.writeBatch(); - reader.read(root); } arrowWriter.end(); } @@ -179,16 +179,21 @@ public class Integration { List<ArrowBlock> recordBatches = arrowReader.getRecordBlocks(); Iterator<ArrowBlock> iterator = recordBatches.iterator(); VectorSchemaRoot jsonRoot; + int totalBatches = 0; while ((jsonRoot = jsonReader.read()) != null && iterator.hasNext()) { ArrowBlock rbBlock = iterator.next(); - arrowReader.loadRecordBatch(rbBlock); + if (!arrowReader.loadRecordBatch(rbBlock)) { + throw new IOException("Expected to load record batch"); + } Validator.compareVectorSchemaRoot(arrowRoot, jsonRoot); jsonRoot.close(); + totalBatches++; } boolean hasMoreJSON = jsonRoot != null; boolean hasMoreArrow = iterator.hasNext(); if (hasMoreJSON || hasMoreArrow) { - throw new IllegalArgumentException("Unexpected RecordBatches. J:" + hasMoreJSON + " " + throw new IllegalArgumentException("Unexpected RecordBatches. Total: " + totalBatches + + " J:" + hasMoreJSON + " " + "A:" + hasMoreArrow); } } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java ---------------------------------------------------------------------- diff --git a/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java b/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java index 41dfd34..ef1a11f 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/StreamToFile.java @@ -41,12 +41,16 @@ public class StreamToFile { try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); // load the first batch before instantiating the writer so that we have any dictionaries - reader.loadNextBatch(); + if (!reader.loadNextBatch()) { + throw new IOException("Unable to read first record batch"); + } try (ArrowFileWriter writer = new ArrowFileWriter(root, reader, Channels.newChannel(out))) { 
writer.start(); - while (root.getRowCount() > 0) { + while (true) { writer.writeBatch(); - reader.loadNextBatch(); + if (!reader.loadNextBatch()) { + break; + } } writer.end(); } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java ---------------------------------------------------------------------- diff --git a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java index 1a38909..34c93ed 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/ArrowFileTestFixtures.java @@ -67,7 +67,9 @@ public class ArrowFileTestFixtures { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); for (ArrowBlock rbBlock : arrowReader.getRecordBlocks()) { - arrowReader.loadRecordBatch(rbBlock); + if (!arrowReader.loadRecordBatch(rbBlock)) { + throw new IOException("Expected to read record batch"); + } validateContent(COUNT, root); } } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java ---------------------------------------------------------------------- diff --git a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java index 7d07588..7cca339 100644 --- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java +++ b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java @@ -118,7 +118,7 @@ public class EchoServerTest { NullableTinyIntVector readVector = (NullableTinyIntVector) reader.getVectorSchemaRoot() .getFieldVectors().get(0); for (int i = 0; i < batches; i++) { - reader.loadNextBatch(); + Assert.assertTrue(reader.loadNextBatch()); assertEquals(16, reader.getVectorSchemaRoot().getRowCount()); 
assertEquals(16, readVector.getAccessor().getValueCount()); for (int j = 0; j < 8; j++) { @@ -126,7 +126,7 @@ public class EchoServerTest { assertTrue(readVector.getAccessor().isNull(j + 8)); } } - reader.loadNextBatch(); + Assert.assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); assertEquals(reader.bytesRead(), writer.bytesWritten()); } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/main/codegen/templates/NullableValueVectors.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/codegen/templates/NullableValueVectors.java b/java/vector/src/main/codegen/templates/NullableValueVectors.java index a50771a..e5257ce 100644 --- a/java/vector/src/main/codegen/templates/NullableValueVectors.java +++ b/java/vector/src/main/codegen/templates/NullableValueVectors.java @@ -122,7 +122,9 @@ public final class ${className} extends BaseDataValueVector implements <#if type public void loadFieldBuffers(ArrowFieldNode fieldNode, List<ArrowBuf> ownBuffers) { <#if type.major = "VarLen"> // variable width values: truncate offset vector buffer to size (#1) - org.apache.arrow.vector.BaseDataValueVector.truncateBufferBasedOnSize(ownBuffers, 1, values.offsetVector.getBufferSizeFor(fieldNode.getLength() + 1)); + org.apache.arrow.vector.BaseDataValueVector.truncateBufferBasedOnSize(ownBuffers, 1, + values.offsetVector.getBufferSizeFor( + fieldNode.getLength() == 0? 
0 : fieldNode.getLength() + 1)); <#else> // fixed width values truncate value vector to size (#1) org.apache.arrow.vector.BaseDataValueVector.truncateBufferBasedOnSize(ownBuffers, 1, values.getBufferSizeFor(fieldNode.getLength())); http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java index 28440a1..f4d6ada 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java @@ -103,14 +103,14 @@ public class ArrowFileReader extends ArrowReader<SeekableReadChannel> { return footer.getRecordBatches(); } - public void loadRecordBatch(ArrowBlock block) throws IOException { + public boolean loadRecordBatch(ArrowBlock block) throws IOException { ensureInitialized(); int blockIndex = footer.getRecordBatches().indexOf(block); if (blockIndex == -1) { throw new IllegalArgumentException("Arrow bock does not exist in record batches: " + block); } currentRecordBatch = blockIndex; - loadNextBatch(); + return loadNextBatch(); } private ArrowDictionaryBatch readDictionaryBatch(SeekableReadChannel in, http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java index 1646fbe..1d33913 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java @@ -89,7 +89,8 @@ public 
abstract class ArrowReader<T extends ReadChannel> implements DictionaryPr } } - public void loadNextBatch() throws IOException { + // Returns true if a batch was read, false on EOS + public boolean loadNextBatch() throws IOException { ensureInitialized(); // read in all dictionary batches, then stop after our first record batch ArrowMessageVisitor<Boolean> visitor = new ArrowMessageVisitor<Boolean>() { @@ -106,9 +107,18 @@ public abstract class ArrowReader<T extends ReadChannel> implements DictionaryPr }; root.setRowCount(0); ArrowMessage message = readMessage(in, allocator); - while (message != null && message.accepts(visitor)) { + + boolean readBatch = false; + while (message != null) { + if (!message.accepts(visitor)) { + readBatch = true; + break; + } + // else read a dictionary message = readMessage(in, allocator); } + + return readBatch; } public long bytesRead() { return in.bytesRead(); } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java ---------------------------------------------------------------------- diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index fde9954..21aa037 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -94,7 +94,7 @@ public class JsonFileReader implements AutoCloseable { } } - public void read(VectorSchemaRoot root) throws IOException { + public boolean read(VectorSchemaRoot root) throws IOException { JsonToken t = parser.nextToken(); if (t == START_OBJECT) { { @@ -111,8 +111,10 @@ public class JsonFileReader implements AutoCloseable { readToken(END_ARRAY); } readToken(END_OBJECT); + return true; } else if (t == END_ARRAY) { root.setRowCount(0); + return false; } else { throw new 
IllegalArgumentException("Invalid token: " + t); } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java ---------------------------------------------------------------------- diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java index a1104ff..11730af 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowFile.java @@ -152,7 +152,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); Assert.assertEquals(count, root.getRowCount()); validateContent(count, root); } @@ -193,7 +193,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); Assert.assertEquals(count, root.getRowCount()); validateComplexContent(count, root); } @@ -263,13 +263,12 @@ public class TestArrowFile extends BaseFileTest { int i = 0; for (int n = 0; n < 2; n++) { - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); Assert.assertEquals("RB #" + i, counts[i], root.getRowCount()); validateContent(counts[i], root); ++i; } - arrowReader.loadNextBatch(); - Assert.assertEquals(0, root.getRowCount()); + Assert.assertFalse(arrowReader.loadNextBatch()); } } @@ -294,7 +293,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - 
arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateUnionData(count, root); } @@ -305,7 +304,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateUnionData(count, root); } } @@ -347,7 +346,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateTinyData(root); } @@ -358,7 +357,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateTinyData(root); } } @@ -433,7 +432,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateFlatDictionary(root.getFieldVectors().get(0), arrowReader); } @@ -444,7 +443,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateFlatDictionary(root.getFieldVectors().get(0), arrowReader); } } @@ -537,7 +536,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - 
arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateNestedDictionary((ListVector) root.getFieldVectors().get(0), arrowReader); } @@ -548,7 +547,7 @@ public class TestArrowFile extends BaseFileTest { VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); Schema schema = root.getSchema(); LOGGER.debug("reading schema: " + schema); - arrowReader.loadNextBatch(); + Assert.assertTrue(arrowReader.loadNextBatch()); validateNestedDictionary((ListVector) root.getFieldVectors().get(0), arrowReader); } } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java ---------------------------------------------------------------------- diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java index e7cdf3f..7e9afd3 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStream.java @@ -19,6 +19,7 @@ package org.apache.arrow.vector.file; import static java.util.Arrays.asList; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; @@ -36,6 +37,7 @@ import org.apache.arrow.vector.stream.ArrowStreamReader; import org.apache.arrow.vector.stream.ArrowStreamWriter; import org.apache.arrow.vector.stream.MessageSerializerTest; import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; import org.junit.Test; public class TestArrowStream extends BaseFileTest { @@ -52,10 +54,10 @@ public class TestArrowStream extends BaseFileTest { ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray()); try (ArrowStreamReader reader = new ArrowStreamReader(in, allocator)) { assertEquals(schema, reader.getVectorSchemaRoot().getSchema()); - // 
Empty should return nothing. Can be called repeatedly. - reader.loadNextBatch(); + // Empty should return false + Assert.assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); - reader.loadNextBatch(); + Assert.assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); } } @@ -90,11 +92,11 @@ public class TestArrowStream extends BaseFileTest { Schema readSchema = reader.getVectorSchemaRoot().getSchema(); assertEquals(schema, readSchema); for (int i = 0; i < numBatches; i++) { - reader.loadNextBatch(); + assertTrue(reader.loadNextBatch()); } // TODO figure out why reader isn't getting padding bytes assertEquals(bytesWritten, reader.bytesRead() + 4); - reader.loadNextBatch(); + assertFalse(reader.loadNextBatch()); assertEquals(0, reader.getVectorSchemaRoot().getRowCount()); } } http://git-wip-us.apache.org/repos/asf/arrow/blob/01114d83/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java ---------------------------------------------------------------------- diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java index 46d4679..20d4482 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/TestArrowStreamPipe.java @@ -105,8 +105,10 @@ public class TestArrowStreamPipe { return message; } @Override - public void loadNextBatch() throws IOException { - super.loadNextBatch(); + public boolean loadNextBatch() throws IOException { + if (!super.loadNextBatch()) { + return false; + } if (!done) { VectorSchemaRoot root = getVectorSchemaRoot(); Assert.assertEquals(16, root.getRowCount()); @@ -120,6 +122,7 @@ public class TestArrowStreamPipe { } } } + return true; } }; } @@ -132,7 +135,7 @@ public class TestArrowStreamPipe { 
reader.getVectorSchemaRoot().getSchema().getFields().get(0).getTypeLayout().getVectorTypes().toString(), reader.getVectorSchemaRoot().getSchema().getFields().get(0).getTypeLayout().getVectors().size() > 0); while (!done) { - reader.loadNextBatch(); + assertTrue(reader.loadNextBatch()); } } catch (IOException e) { e.printStackTrace();
