Repository: arrow
Updated Branches:
  refs/heads/master 6768f5268 -> 8bf567e63
ARROW-1136: [C++] Add null checks for invalid streams

Author: Wes McKinney <wes.mckin...@twosigma.com>

Closes #770 from wesm/ARROW-1136 and squashes the following commits:

6ae5cd82 [Wes McKinney] Centralize null checking
bc3ec207 [Wes McKinney] Add null checks for invalid streams

Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/8bf567e6
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/8bf567e6
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/8bf567e6

Branch: refs/heads/master
Commit: 8bf567e636e6f8a7e779a6f89ad3169f3ffa9fba
Parents: 6768f52
Author: Wes McKinney <wes.mckin...@twosigma.com>
Authored: Fri Jun 23 14:26:12 2017 -0400
Committer: Wes McKinney <wes.mckin...@twosigma.com>
Committed: Fri Jun 23 14:26:12 2017 -0400

----------------------------------------------------------------------
 cpp/src/arrow/ipc/reader.cc      | 20 ++++++++++++++------
 python/pyarrow/tests/test_ipc.py | 10 ++++++++++
 2 files changed, 24 insertions(+), 6 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/arrow/blob/8bf567e6/cpp/src/arrow/ipc/reader.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc
index 2b7b90f..7fef847 100644
--- a/cpp/src/arrow/ipc/reader.cc
+++ b/cpp/src/arrow/ipc/reader.cc
@@ -162,7 +162,7 @@ static inline FileBlock FileBlockFromFlatbuffer(const flatbuf::Block* block) {
   return FileBlock(block->offset(), block->metaDataLength(), block->bodyLength());
 }
 
-static inline std::string message_type_name(Message::Type type) {
+static inline std::string FormatMessageType(Message::Type type) {
   switch (type) {
     case Message::SCHEMA:
       return "schema";
@@ -188,14 +188,22 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
     return ReadSchema();
   }
 
-  Status ReadNextMessage(Message::Type expected_type, std::shared_ptr<Message>* message) {
+  Status ReadNextMessage(Message::Type expected_type, bool allow_null,
+                         std::shared_ptr<Message>* message) {
     RETURN_NOT_OK(ReadMessage(stream_.get(), message));
 
+    if (!(*message) && !allow_null) {
+      std::stringstream ss;
+      ss << "Expected " << FormatMessageType(expected_type)
+         << " message in stream, was null or length 0";
+      return Status::Invalid(ss.str());
+    }
+
     if ((*message) == nullptr) { return Status::OK(); }
 
     if ((*message)->type() != expected_type) {
       std::stringstream ss;
-      ss << "Message not expected type: " << message_type_name(expected_type)
+      ss << "Message not expected type: " << FormatMessageType(expected_type)
          << ", was: " << (*message)->type();
       return Status::IOError(ss.str());
     }
@@ -213,7 +221,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status ReadNextDictionary() {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::DICTIONARY_BATCH, false, &message));
 
     std::shared_ptr<Buffer> batch_body;
     RETURN_NOT_OK(ReadExact(message->body_length(), &batch_body))
@@ -227,7 +235,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status ReadSchema() {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, false, &message));
 
     RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_));
 
@@ -243,7 +251,7 @@ class RecordBatchStreamReader::RecordBatchStreamReaderImpl {
 
   Status GetNextRecordBatch(std::shared_ptr<RecordBatch>* batch) {
     std::shared_ptr<Message> message;
-    RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, &message));
+    RETURN_NOT_OK(ReadNextMessage(Message::RECORD_BATCH, true, &message));
 
     if (message == nullptr) {
       // End of stream

http://git-wip-us.apache.org/repos/asf/arrow/blob/8bf567e6/python/pyarrow/tests/test_ipc.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index eeea39a..47ef756 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -72,6 +72,11 @@ class TestFile(MessagingTest, unittest.TestCase):
     def _get_writer(self, sink, schema):
         return pa.RecordBatchFileWriter(sink, schema)
 
+    def test_empty_file(self):
+        buf = io.BytesIO(b'')
+        with pytest.raises(pa.ArrowInvalid):
+            pa.open_file(buf)
+
     def test_simple_roundtrip(self):
         batches = self.write_batches()
         file_contents = self._get_source()
@@ -101,6 +106,11 @@ class TestStream(MessagingTest, unittest.TestCase):
     def _get_writer(self, sink, schema):
         return pa.RecordBatchStreamWriter(sink, schema)
 
+    def test_empty_stream(self):
+        buf = io.BytesIO(b'')
+        with pytest.raises(pa.ArrowInvalid):
+            pa.open_stream(buf)
+
     def test_simple_roundtrip(self):
         batches = self.write_batches()
         file_contents = self._get_source()