Repository: arrow
Updated Branches:
  refs/heads/master 15b874e47 -> 957a0e678
ARROW-717: [C++] Implement IPC zero-copy round trip for tensors

This patch provides:

```cpp
int32_t metadata_length;
int64_t body_length;
RETURN_NOT_OK(WriteTensor(tensor, file, &metadata_length, &body_length));

std::shared_ptr<Tensor> result;
RETURN_NOT_OK(ReadTensor(offset, file, &result));
```

Also implemented `Tensor::Equals` and did some refactoring / code
simplification in compare.cc.

Author: Wes McKinney <wes.mckin...@twosigma.com>

Closes #454 from wesm/ARROW-717 and squashes the following commits:

6c15481 [Wes McKinney] Tensor IPC read/write, and refactoring / code scrubbing

Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/957a0e67
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/957a0e67
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/957a0e67

Branch: refs/heads/master
Commit: 957a0e67836b66f8ff4fc3fdae343553c589b53f
Parents: 15b874e
Author: Wes McKinney <wes.mckin...@twosigma.com>
Authored: Thu Mar 30 18:03:26 2017 -0400
Committer: Wes McKinney <wes.mckin...@twosigma.com>
Committed: Thu Mar 30 18:03:26 2017 -0400

----------------------------------------------------------------------
 cpp/src/arrow/buffer.cc                  |   6 +-
 cpp/src/arrow/compare.cc                 | 330 ++++++++++++--------------
 cpp/src/arrow/compare.h                  |   4 +
 cpp/src/arrow/ipc/ipc-read-write-test.cc |  54 ++++-
 cpp/src/arrow/ipc/metadata.cc            | 266 +++++++++++++++------
 cpp/src/arrow/ipc/metadata.h             |  67 +++---
 cpp/src/arrow/ipc/reader.cc              |  79 +++---
 cpp/src/arrow/ipc/reader.h               |  32 +--
 cpp/src/arrow/ipc/writer.cc              |  79 +++---
 cpp/src/arrow/ipc/writer.h               |  12 +-
 cpp/src/arrow/tensor-test.cc             |  25 +-
 cpp/src/arrow/tensor.cc                  |  67 +++++-
 cpp/src/arrow/tensor.h                   |  18 +-
 cpp/src/arrow/type_traits.h              |  11 +
 cpp/src/arrow/visitor_inline.h           |  26 ++
 format/Tensor.fbs                        |  14 +-
 16 files changed, 656 insertions(+), 434 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/buffer.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/buffer.cc b/cpp/src/arrow/buffer.cc
index be747e1..5962340 100644
--- a/cpp/src/arrow/buffer.cc
+++ b/cpp/src/arrow/buffer.cc
@@ -27,11 +27,9 @@
 
 namespace arrow {
 
-Buffer::Buffer(const std::shared_ptr<Buffer>& parent, int64_t offset, int64_t size) {
-  data_ = parent->data() + offset;
-  size_ = size;
+Buffer::Buffer(const std::shared_ptr<Buffer>& parent, int64_t offset, int64_t size)
+    : Buffer(parent->data() + offset, size) {
   parent_ = parent;
-  capacity_ = size;
 }
 
 Buffer::~Buffer() {}


http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/compare.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc
index f786222..c2580b4 100644
--- a/cpp/src/arrow/compare.cc
+++ b/cpp/src/arrow/compare.cc
@@ -25,6 +25,7 @@
 
 #include "arrow/array.h"
 #include "arrow/status.h"
+#include "arrow/tensor.h"
 #include "arrow/type.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/bit-util.h"
@@ -36,7 +37,7 @@ namespace arrow {
 // ----------------------------------------------------------------------
 // Public method implementations
 
-class RangeEqualsVisitor : public ArrayVisitor {
+class RangeEqualsVisitor {
  public:
   RangeEqualsVisitor(const Array& right, int64_t left_start_idx, int64_t left_end_idx,
       int64_t right_start_idx)
@@ -46,12 +47,6 @@ class RangeEqualsVisitor : public ArrayVisitor {
         right_start_idx_(right_start_idx),
         result_(false) {}
 
-  Status Visit(const NullArray& left) override {
-
UNUSED(left); - result_ = true; - return Status::OK(); - } - template <typename ArrayType> inline Status CompareValues(const ArrayType& left) { const auto& right = static_cast<const ArrayType&>(right_); @@ -96,108 +91,6 @@ class RangeEqualsVisitor : public ArrayVisitor { return true; } - Status Visit(const BooleanArray& left) override { - return CompareValues<BooleanArray>(left); - } - - Status Visit(const Int8Array& left) override { return CompareValues<Int8Array>(left); } - - Status Visit(const Int16Array& left) override { - return CompareValues<Int16Array>(left); - } - Status Visit(const Int32Array& left) override { - return CompareValues<Int32Array>(left); - } - Status Visit(const Int64Array& left) override { - return CompareValues<Int64Array>(left); - } - Status Visit(const UInt8Array& left) override { - return CompareValues<UInt8Array>(left); - } - Status Visit(const UInt16Array& left) override { - return CompareValues<UInt16Array>(left); - } - Status Visit(const UInt32Array& left) override { - return CompareValues<UInt32Array>(left); - } - Status Visit(const UInt64Array& left) override { - return CompareValues<UInt64Array>(left); - } - Status Visit(const FloatArray& left) override { - return CompareValues<FloatArray>(left); - } - Status Visit(const DoubleArray& left) override { - return CompareValues<DoubleArray>(left); - } - - Status Visit(const HalfFloatArray& left) override { - return Status::NotImplemented("Half float type"); - } - - Status Visit(const StringArray& left) override { - result_ = CompareBinaryRange(left); - return Status::OK(); - } - - Status Visit(const BinaryArray& left) override { - result_ = CompareBinaryRange(left); - return Status::OK(); - } - - Status Visit(const FixedWidthBinaryArray& left) override { - const auto& right = static_cast<const FixedWidthBinaryArray&>(right_); - - int32_t width = left.byte_width(); - - const uint8_t* left_data = left.raw_data() + left.offset() * width; - const uint8_t* right_data = right.raw_data() + right.offset() * width; - - for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_; - ++i, ++o_i) { - const bool is_null = left.IsNull(i); - if (is_null != right.IsNull(o_i)) { - result_ = false; - return Status::OK(); - } - if (is_null) continue; - - if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) { - result_ = false; - return Status::OK(); - } - } - result_ = true; - return Status::OK(); - } - - Status Visit(const Date32Array& left) override { - return CompareValues<Date32Array>(left); - } - - Status Visit(const Date64Array& left) override { - return CompareValues<Date64Array>(left); - } - - Status Visit(const Time32Array& left) override { - return CompareValues<Time32Array>(left); - } - - Status Visit(const Time64Array& left) override { - return CompareValues<Time64Array>(left); - } - - Status Visit(const TimestampArray& left) override { - return CompareValues<TimestampArray>(left); - } - - Status Visit(const IntervalArray& left) override { - return CompareValues<IntervalArray>(left); - } - - Status Visit(const DecimalArray& left) override { - return Status::NotImplemented("Decimal type"); - } - bool CompareLists(const ListArray& left) { const auto& right = static_cast<const ListArray&>(right_); @@ -225,11 +118,6 @@ class RangeEqualsVisitor : public ArrayVisitor { return true; } - Status Visit(const ListArray& left) override { - result_ = CompareLists(left); - return Status::OK(); - } - bool CompareStructs(const StructArray& left) { const auto& right = static_cast<const 
StructArray&>(right_); bool equal_fields = true; @@ -251,11 +139,6 @@ class RangeEqualsVisitor : public ArrayVisitor { return true; } - Status Visit(const StructArray& left) override { - result_ = CompareStructs(left); - return Status::OK(); - } - bool CompareUnions(const UnionArray& left) const { const auto& right = static_cast<const UnionArray&>(right_); @@ -314,12 +197,73 @@ class RangeEqualsVisitor : public ArrayVisitor { return true; } - Status Visit(const UnionArray& left) override { + Status Visit(const BinaryArray& left) { + result_ = CompareBinaryRange(left); + return Status::OK(); + } + + Status Visit(const FixedWidthBinaryArray& left) { + const auto& right = static_cast<const FixedWidthBinaryArray&>(right_); + + int32_t width = left.byte_width(); + + const uint8_t* left_data = nullptr; + const uint8_t* right_data = nullptr; + + if (left.data()) { left_data = left.raw_data() + left.offset() * width; } + + if (right.data()) { right_data = right.raw_data() + right.offset() * width; } + + for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_; + ++i, ++o_i) { + const bool is_null = left.IsNull(i); + if (is_null != right.IsNull(o_i)) { + result_ = false; + return Status::OK(); + } + if (is_null) continue; + + if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) { + result_ = false; + return Status::OK(); + } + } + result_ = true; + return Status::OK(); + } + + Status Visit(const NullArray& left) { + UNUSED(left); + result_ = true; + return Status::OK(); + } + + template <typename T> + typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value, Status>::type Visit( + const T& left) { + return CompareValues<T>(left); + } + + Status Visit(const DecimalArray& left) { + return Status::NotImplemented("Decimal type"); + } + + Status Visit(const ListArray& left) { + result_ = CompareLists(left); + return Status::OK(); + } + + Status Visit(const StructArray& left) { + result_ = CompareStructs(left); + return Status::OK(); + } + + Status Visit(const UnionArray& left) { result_ = CompareUnions(left); return Status::OK(); } - Status Visit(const DictionaryArray& left) override { + Status Visit(const DictionaryArray& left) { const auto& right = static_cast<const DictionaryArray&>(right_); if (!left.dictionary()->Equals(right.dictionary())) { result_ = false; @@ -346,9 +290,9 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor { explicit ArrayEqualsVisitor(const Array& right) : RangeEqualsVisitor(right, 0, right.length(), 0) {} - Status Visit(const NullArray& left) override { return Status::OK(); } + Status Visit(const NullArray& left) { return Status::OK(); } - Status Visit(const BooleanArray& left) override { + Status Visit(const BooleanArray& left) { const auto& right = static_cast<const BooleanArray&>(right_); if (left.null_count() > 0) { const uint8_t* left_data = left.data()->data(); @@ -372,64 +316,39 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor { bool IsEqualPrimitive(const PrimitiveArray& left) { const auto& right = static_cast<const PrimitiveArray&>(right_); const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type()); - const int value_byte_size = size_meta.bit_width() / 8; - DCHECK_GT(value_byte_size, 0); + const int byte_width = size_meta.bit_width() / 8; + + const uint8_t* left_data = nullptr; + const uint8_t* right_data = nullptr; + + if (left.data()) { left_data = left.data()->data() + left.offset() * byte_width; } - const uint8_t* left_data = left.data()->data() + left.offset() * value_byte_size; - const 
uint8_t* right_data = right.data()->data() + right.offset() * value_byte_size; + if (right.data()) { right_data = right.data()->data() + right.offset() * byte_width; } if (left.null_count() > 0) { for (int64_t i = 0; i < left.length(); ++i) { - if (!left.IsNull(i) && memcmp(left_data, right_data, value_byte_size)) { + if (!left.IsNull(i) && memcmp(left_data, right_data, byte_width)) { return false; } - left_data += value_byte_size; - right_data += value_byte_size; + left_data += byte_width; + right_data += byte_width; } return true; } else { return memcmp(left_data, right_data, - static_cast<size_t>(value_byte_size * left.length())) == 0; + static_cast<size_t>(byte_width * left.length())) == 0; } } - Status ComparePrimitive(const PrimitiveArray& left) { + template <typename T> + typename std::enable_if<std::is_base_of<PrimitiveArray, T>::value && + !std::is_base_of<BooleanArray, T>::value, + Status>::type + Visit(const T& left) { result_ = IsEqualPrimitive(left); return Status::OK(); } - Status Visit(const Int8Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Int16Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Int32Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Int64Array& left) override { return ComparePrimitive(left); } - - Status Visit(const UInt8Array& left) override { return ComparePrimitive(left); } - - Status Visit(const UInt16Array& left) override { return ComparePrimitive(left); } - - Status Visit(const UInt32Array& left) override { return ComparePrimitive(left); } - - Status Visit(const UInt64Array& left) override { return ComparePrimitive(left); } - - Status Visit(const FloatArray& left) override { return ComparePrimitive(left); } - - Status Visit(const DoubleArray& left) override { return ComparePrimitive(left); } - - Status Visit(const Date32Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Date64Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Time32Array& left) override { return ComparePrimitive(left); } - - Status Visit(const Time64Array& left) override { return ComparePrimitive(left); } - - Status Visit(const TimestampArray& left) override { return ComparePrimitive(left); } - - Status Visit(const IntervalArray& left) override { return ComparePrimitive(left); } - template <typename ArrayType> bool ValueOffsetsEqual(const ArrayType& left) { const auto& right = static_cast<const ArrayType&>(right_); @@ -494,17 +413,12 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor { } } - Status Visit(const StringArray& left) override { - result_ = CompareBinary(left); - return Status::OK(); - } - - Status Visit(const BinaryArray& left) override { + Status Visit(const BinaryArray& left) { result_ = CompareBinary(left); return Status::OK(); } - Status Visit(const ListArray& left) override { + Status Visit(const ListArray& left) { const auto& right = static_cast<const ListArray&>(right_); bool equal_offsets = ValueOffsetsEqual<ListArray>(left); if (!equal_offsets) { @@ -523,7 +437,7 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor { return Status::OK(); } - Status Visit(const DictionaryArray& left) override { + Status Visit(const DictionaryArray& left) { const auto& right = static_cast<const DictionaryArray&>(right_); if (!left.dictionary()->Equals(right.dictionary())) { result_ = false; @@ -532,6 +446,13 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor { } return Status::OK(); } + + template <typename T> + typename 
std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value, + Status>::type + Visit(const T& left) { + return RangeEqualsVisitor::Visit(left); + } }; template <typename TYPE> @@ -560,14 +481,15 @@ inline bool FloatingApproxEquals( class ApproxEqualsVisitor : public ArrayEqualsVisitor { public: using ArrayEqualsVisitor::ArrayEqualsVisitor; + using ArrayEqualsVisitor::Visit; - Status Visit(const FloatArray& left) override { + Status Visit(const FloatArray& left) { result_ = FloatingApproxEquals<FloatType>(left, static_cast<const FloatArray&>(right_)); return Status::OK(); } - Status Visit(const DoubleArray& left) override { + Status Visit(const DoubleArray& left) { result_ = FloatingApproxEquals<DoubleType>(left, static_cast<const DoubleArray&>(right_)); return Status::OK(); @@ -586,7 +508,8 @@ static bool BaseDataEquals(const Array& left, const Array& right) { return true; } -Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) { +template <typename VISITOR> +inline Status ArrayEqualsImpl(const Array& left, const Array& right, bool* are_equal) { // The arrays are the same object if (&left == &right) { *are_equal = true; @@ -595,13 +518,21 @@ Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) { } else if (left.length() == 0) { *are_equal = true; } else { - ArrayEqualsVisitor visitor(right); - RETURN_NOT_OK(left.Accept(&visitor)); + VISITOR visitor(right); + RETURN_NOT_OK(VisitArrayInline(left, &visitor)); *are_equal = visitor.result(); } return Status::OK(); } +Status ArrayEquals(const Array& left, const Array& right, bool* are_equal) { + return ArrayEqualsImpl<ArrayEqualsVisitor>(left, right, are_equal); +} + +Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal) { + return ArrayEqualsImpl<ApproxEqualsVisitor>(left, right, are_equal); +} + Status ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, int64_t left_end_idx, int64_t right_start_idx, bool* are_equal) { if (&left == &right) { @@ -612,23 +543,56 @@ Status ArrayRangeEquals(const Array& left, const Array& right, int64_t left_star *are_equal = true; } else { RangeEqualsVisitor visitor(right, left_start_idx, left_end_idx, right_start_idx); - RETURN_NOT_OK(left.Accept(&visitor)); + RETURN_NOT_OK(VisitArrayInline(left, &visitor)); *are_equal = visitor.result(); } return Status::OK(); } -Status ArrayApproxEquals(const Array& left, const Array& right, bool* are_equal) { +// ---------------------------------------------------------------------- +// Implement TensorEquals + +class TensorEqualsVisitor { + public: + explicit TensorEqualsVisitor(const Tensor& right) : right_(right) {} + + template <typename TensorType> + Status Visit(const TensorType& left) { + const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type()); + const int byte_width = size_meta.bit_width() / 8; + DCHECK_GT(byte_width, 0); + + const uint8_t* left_data = left.data()->data(); + const uint8_t* right_data = right_.data()->data(); + + result_ = + memcmp(left_data, right_data, static_cast<size_t>(byte_width * left.size())) == 0; + return Status::OK(); + } + + bool result() const { return result_; } + + protected: + const Tensor& right_; + bool result_; +}; + +Status TensorEquals(const Tensor& left, const Tensor& right, bool* are_equal) { // The arrays are the same object if (&left == &right) { *are_equal = true; - } else if (!BaseDataEquals(left, right)) { + } else if (left.type_enum() != right.type_enum()) { *are_equal = false; - } else if 
(left.length() == 0) { + } else if (left.size() == 0) { *are_equal = true; } else { - ApproxEqualsVisitor visitor(right); - RETURN_NOT_OK(left.Accept(&visitor)); + if (!left.is_contiguous() || !right.is_contiguous()) { + return Status::NotImplemented( + "Comparison not implemented for non-contiguous tensors"); + } + + TensorEqualsVisitor visitor(right); + RETURN_NOT_OK(VisitTensorInline(left, &visitor)); *are_equal = visitor.result(); } return Status::OK(); http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/compare.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index 1ddf049..522b11d 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -29,10 +29,14 @@ namespace arrow { class Array; struct DataType; class Status; +class Tensor; /// Returns true if the arrays are exactly equal Status ARROW_EXPORT ArrayEquals(const Array& left, const Array& right, bool* are_equal); +Status ARROW_EXPORT TensorEquals( + const Tensor& left, const Tensor& right, bool* are_equal); + /// Returns true if the arrays are approximately equal. For non-floating point /// types, this is equivalent to ArrayEquals(left, right) Status ARROW_EXPORT ArrayApproxEquals( http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/ipc-read-write-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/ipc-read-write-test.cc b/cpp/src/arrow/ipc/ipc-read-write-test.cc index 6ddda3f..74ca017 100644 --- a/cpp/src/arrow/ipc/ipc-read-write-test.cc +++ b/cpp/src/arrow/ipc/ipc-read-write-test.cc @@ -25,16 +25,16 @@ #include "gtest/gtest.h" #include "arrow/array.h" +#include "arrow/buffer.h" #include "arrow/io/memory.h" #include "arrow/io/test-common.h" #include "arrow/ipc/api.h" #include "arrow/ipc/test-common.h" #include "arrow/ipc/util.h" - -#include "arrow/buffer.h" #include "arrow/memory_pool.h" #include "arrow/pretty_print.h" #include "arrow/status.h" +#include "arrow/tensor.h" #include "arrow/test-util.h" #include "arrow/util/bit-util.h" @@ -56,13 +56,10 @@ class TestSchemaMetadata : public ::testing::Test { ASSERT_EQ(Message::SCHEMA, message->type()); - auto schema_msg = std::make_shared<SchemaMetadata>(message); - ASSERT_EQ(schema.num_fields(), schema_msg->num_fields()); - DictionaryMemo empty_memo; std::shared_ptr<Schema> schema2; - ASSERT_OK(schema_msg->GetSchema(empty_memo, &schema2)); + ASSERT_OK(GetSchema(message->header(), empty_memo, &schema2)); AssertSchemaEqual(schema, *schema2); } @@ -90,7 +87,7 @@ TEST_F(TestSchemaMetadata, PrimitiveFields) { } TEST_F(TestSchemaMetadata, NestedFields) { - auto type = std::make_shared<ListType>(std::make_shared<Int32Type>()); + auto type = list(int32()); auto f0 = field("f0", type); std::shared_ptr<StructType> type2( @@ -532,7 +529,6 @@ TEST_F(TestIpcRoundTrip, LargeRecordBatch) { // 512 MB constexpr int64_t kBufferSize = 1 << 29; - ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); std::shared_ptr<RecordBatch> result; @@ -580,5 +576,47 @@ TEST_F(TestFileFormat, DictionaryRoundTrip) { CheckBatchDictionaries(*out_batches[0]); } +class TestTensorRoundTrip : public ::testing::Test, public IpcTestFixture { + public: + void SetUp() { pool_ = default_memory_pool(); } + void TearDown() { io::MemoryMapFixture::TearDown(); } + + void CheckTensorRoundTrip(const Tensor& tensor) { + int32_t metadata_length; + int64_t body_length; + + ASSERT_OK(mmap_->Seek(0)); + + 
ASSERT_OK(WriteTensor(tensor, mmap_.get(), &metadata_length, &body_length)); + + std::shared_ptr<Tensor> result; + ASSERT_OK(ReadTensor(0, mmap_.get(), &result)); + + ASSERT_TRUE(tensor.Equals(*result)); + } +}; + +TEST_F(TestTensorRoundTrip, BasicRoundtrip) { + std::string path = "test-write-tensor"; + constexpr int64_t kBufferSize = 1 << 20; + ASSERT_OK(io::MemoryMapFixture::InitMemoryMap(kBufferSize, path, &mmap_)); + + std::vector<int64_t> shape = {4, 6}; + std::vector<int64_t> strides = {48, 8}; + std::vector<std::string> dim_names = {"foo", "bar"}; + int64_t size = 24; + + std::vector<int64_t> values; + test::randint<int64_t>(size, 0, 100, &values); + + auto data = test::GetBufferFromVector(values); + + Int64Tensor t0(data, shape, strides, dim_names); + Int64Tensor tzero(data, {}, {}, {}); + + CheckTensorRoundTrip(t0); + CheckTensorRoundTrip(tzero); +} + } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/metadata.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/metadata.cc b/cpp/src/arrow/ipc/metadata.cc index 6d9fabd..076a6e7 100644 --- a/cpp/src/arrow/ipc/metadata.cc +++ b/cpp/src/arrow/ipc/metadata.cc @@ -20,6 +20,7 @@ #include <cstdint> #include <memory> #include <sstream> +#include <string> #include <vector> #include "flatbuffers/flatbuffers.h" @@ -29,7 +30,10 @@ #include "arrow/io/interfaces.h" #include "arrow/ipc/File_generated.h" #include "arrow/ipc/Message_generated.h" +#include "arrow/ipc/Tensor_generated.h" +#include "arrow/ipc/util.h" #include "arrow/status.h" +#include "arrow/tensor.h" #include "arrow/type.h" namespace arrow { @@ -418,6 +422,46 @@ static Status TypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type, return Status::OK(); } +static Status TensorTypeToFlatbuffer(FBB& fbb, const std::shared_ptr<DataType>& type, + flatbuf::Type* out_type, Offset* offset) { + switch (type->type) { + case Type::UINT8: + INT_TO_FB_CASE(8, false); + case Type::INT8: + INT_TO_FB_CASE(8, true); + case Type::UINT16: + INT_TO_FB_CASE(16, false); + case Type::INT16: + INT_TO_FB_CASE(16, true); + case Type::UINT32: + INT_TO_FB_CASE(32, false); + case Type::INT32: + INT_TO_FB_CASE(32, true); + case Type::UINT64: + INT_TO_FB_CASE(64, false); + case Type::INT64: + INT_TO_FB_CASE(64, true); + case Type::HALF_FLOAT: + *out_type = flatbuf::Type_FloatingPoint; + *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_HALF); + break; + case Type::FLOAT: + *out_type = flatbuf::Type_FloatingPoint; + *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_SINGLE); + break; + case Type::DOUBLE: + *out_type = flatbuf::Type_FloatingPoint; + *offset = FloatToFlatbuffer(fbb, flatbuf::Precision_DOUBLE); + break; + default: + *out_type = flatbuf::Type_NONE; // Make clang-tidy happy + std::stringstream ss; + ss << "Unable to convert type: " << type->ToString() << std::endl; + return Status::NotImplemented(ss.str()); + } + return Status::OK(); +} + static DictionaryOffset GetDictionaryEncoding( FBB& fbb, const DictionaryType& type, DictionaryMemo* memo) { int64_t dictionary_id = memo->GetId(type.dictionary()); @@ -552,7 +596,7 @@ static Status WriteFlatbufferBuilder(FBB& fbb, std::shared_ptr<Buffer>* out) { return Status::OK(); } -static Status WriteMessage(FBB& fbb, flatbuf::MessageHeader header_type, +static Status WriteFBMessage(FBB& fbb, flatbuf::MessageHeader header_type, flatbuffers::Offset<void> header, int64_t body_length, std::shared_ptr<Buffer>* out) { auto message = 
flatbuf::CreateMessage(fbb, kMetadataVersion, header_type, header, body_length); @@ -565,7 +609,7 @@ Status WriteSchemaMessage( FBB fbb; flatbuffers::Offset<flatbuf::Schema> fb_schema; RETURN_NOT_OK(SchemaToFlatbuffer(fbb, schema, dictionary_memo, &fb_schema)); - return WriteMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out); + return WriteFBMessage(fbb, flatbuf::MessageHeader_Schema, fb_schema.Union(), 0, out); } using FieldNodeVector = @@ -620,10 +664,39 @@ Status WriteRecordBatchMessage(int64_t length, int64_t body_length, FBB fbb; RecordBatchOffset record_batch; RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); - return WriteMessage( + return WriteFBMessage( fbb, flatbuf::MessageHeader_RecordBatch, record_batch.Union(), body_length, out); } +Status WriteTensorMessage( + const Tensor& tensor, int64_t buffer_start_offset, std::shared_ptr<Buffer>* out) { + using TensorDimOffset = flatbuffers::Offset<flatbuf::TensorDim>; + using TensorOffset = flatbuffers::Offset<flatbuf::Tensor>; + + FBB fbb; + + flatbuf::Type fb_type_type; + Offset fb_type; + RETURN_NOT_OK(TensorTypeToFlatbuffer(fbb, tensor.type(), &fb_type_type, &fb_type)); + + std::vector<TensorDimOffset> dims; + for (int i = 0; i < tensor.ndim(); ++i) { + FBString name = fbb.CreateString(tensor.dim_name(i)); + dims.push_back(flatbuf::CreateTensorDim(fbb, tensor.shape()[i], name)); + } + + auto fb_shape = fbb.CreateVector(dims); + auto fb_strides = fbb.CreateVector(tensor.strides()); + int64_t body_length = tensor.data()->size(); + flatbuf::Buffer buffer(-1, buffer_start_offset, body_length); + + TensorOffset fb_tensor = + flatbuf::CreateTensor(fbb, fb_type_type, fb_type, fb_shape, fb_strides, &buffer); + + return WriteFBMessage( + fbb, flatbuf::MessageHeader_Tensor, fb_tensor.Union(), body_length, out); +} + Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length, const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers, std::shared_ptr<Buffer>* out) { @@ -631,7 +704,7 @@ Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length, RecordBatchOffset record_batch; RETURN_NOT_OK(MakeRecordBatch(fbb, length, body_length, nodes, buffers, &record_batch)); auto dictionary_batch = flatbuf::CreateDictionaryBatch(fbb, id, record_batch).Union(); - return WriteMessage( + return WriteFBMessage( fbb, flatbuf::MessageHeader_DictionaryBatch, dictionary_batch, body_length, out); } @@ -746,6 +819,8 @@ class Message::MessageImpl { return Message::DICTIONARY_BATCH; case flatbuf::MessageHeader_RecordBatch: return Message::RECORD_BATCH; + case flatbuf::MessageHeader_Tensor: + return Message::TENSOR; default: return Message::NONE; } @@ -790,95 +865,78 @@ const void* Message::header() const { } // ---------------------------------------------------------------------- -// SchemaMetadata - -class MessageHolder { - public: - void set_message(const std::shared_ptr<Message>& message) { message_ = message; } - void set_buffer(const std::shared_ptr<Buffer>& buffer) { buffer_ = buffer; } - - protected: - // Possible parents, owns the flatbuffer data - std::shared_ptr<Message> message_; - std::shared_ptr<Buffer> buffer_; -}; - -class SchemaMetadata::SchemaMetadataImpl : public MessageHolder { - public: - explicit SchemaMetadataImpl(const void* schema) - : schema_(static_cast<const flatbuf::Schema*>(schema)) {} - - const flatbuf::Field* get_field(int i) const { return schema_->fields()->Get(i); } - int num_fields() const { return 
schema_->fields()->size(); } - - Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) const { - const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary(); - if (dict_metadata == nullptr) { - // Field is not dictionary encoded. Visit children - auto children = field->children(); - for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) { - RETURN_NOT_OK(VisitField(children->Get(i), id_to_field)); - } - } else { - // Field is dictionary encoded. Construct the data type for the - // dictionary (no descendents can be dictionary encoded) - std::shared_ptr<Field> dictionary_field; - RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field)); - (*id_to_field)[dict_metadata->id()] = dictionary_field; +static Status VisitField(const flatbuf::Field* field, DictionaryTypeMap* id_to_field) { + const flatbuf::DictionaryEncoding* dict_metadata = field->dictionary(); + if (dict_metadata == nullptr) { + // Field is not dictionary encoded. Visit children + auto children = field->children(); + for (flatbuffers::uoffset_t i = 0; i < children->size(); ++i) { + RETURN_NOT_OK(VisitField(children->Get(i), id_to_field)); } - return Status::OK(); + } else { + // Field is dictionary encoded. Construct the data type for the + // dictionary (no descendents can be dictionary encoded) + std::shared_ptr<Field> dictionary_field; + RETURN_NOT_OK(FieldFromFlatbufferDictionary(field, &dictionary_field)); + (*id_to_field)[dict_metadata->id()] = dictionary_field; } + return Status::OK(); +} - Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const { - for (int i = 0; i < num_fields(); ++i) { - RETURN_NOT_OK(VisitField(get_field(i), id_to_field)); - } - return Status::OK(); +Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field) { + auto schema = static_cast<const flatbuf::Schema*>(opaque_schema); + int num_fields = static_cast<int>(schema->fields()->size()); + for (int i = 0; i < num_fields; ++i) { + RETURN_NOT_OK(VisitField(schema->fields()->Get(i), id_to_field)); } - - private: - const flatbuf::Schema* schema_; -}; - -SchemaMetadata::SchemaMetadata(const std::shared_ptr<Message>& message) - : SchemaMetadata(message->impl_->header()) { - impl_->set_message(message); + return Status::OK(); } -SchemaMetadata::SchemaMetadata(const void* header) { - impl_.reset(new SchemaMetadataImpl(header)); -} +Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo, + std::shared_ptr<Schema>* out) { + auto schema = static_cast<const flatbuf::Schema*>(opaque_schema); + int num_fields = static_cast<int>(schema->fields()->size()); -SchemaMetadata::SchemaMetadata(const std::shared_ptr<Buffer>& buffer, int64_t offset) - : SchemaMetadata(buffer->data() + offset) { - // Preserve ownership - impl_->set_buffer(buffer); + std::vector<std::shared_ptr<Field>> fields(num_fields); + for (int i = 0; i < num_fields; ++i) { + const flatbuf::Field* field = schema->fields()->Get(i); + RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); + } + *out = std::make_shared<Schema>(fields); + return Status::OK(); } -SchemaMetadata::~SchemaMetadata() {} +Status GetTensorMetadata(const void* opaque_tensor, std::shared_ptr<DataType>* type, + std::vector<int64_t>* shape, std::vector<int64_t>* strides, + std::vector<std::string>* dim_names) { + auto tensor = static_cast<const flatbuf::Tensor*>(opaque_tensor); -int SchemaMetadata::num_fields() const { - return impl_->num_fields(); -} + int ndim = 
static_cast<int>(tensor->shape()->size()); -Status SchemaMetadata::GetDictionaryTypes(DictionaryTypeMap* id_to_field) const { - return impl_->GetDictionaryTypes(id_to_field); -} + for (int i = 0; i < ndim; ++i) { + auto dim = tensor->shape()->Get(i); -Status SchemaMetadata::GetSchema( - const DictionaryMemo& dictionary_memo, std::shared_ptr<Schema>* out) const { - std::vector<std::shared_ptr<Field>> fields(num_fields()); - for (int i = 0; i < this->num_fields(); ++i) { - const flatbuf::Field* field = impl_->get_field(i); - RETURN_NOT_OK(FieldFromFlatbuffer(field, dictionary_memo, &fields[i])); + shape->push_back(dim->size()); + auto fb_name = dim->name(); + if (fb_name == 0) { + dim_names->push_back(""); + } else { + dim_names->push_back(fb_name->str()); + } } - *out = std::make_shared<Schema>(fields); - return Status::OK(); + + if (tensor->strides()->size() > 0) { + for (int i = 0; i < ndim; ++i) { + strides->push_back(tensor->strides()->Get(i)); + } + } + + return TypeFromFlatbuffer(tensor->type_type(), tensor->type(), {}, type); } // ---------------------------------------------------------------------- -// Conveniences +// Read and write messages Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file, std::shared_ptr<Message>* message) { @@ -896,5 +954,61 @@ Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile return Message::Open(buffer, 4, message); } +Status ReadMessage(io::InputStream* file, std::shared_ptr<Message>* message) { + std::shared_ptr<Buffer> buffer; + RETURN_NOT_OK(file->Read(sizeof(int32_t), &buffer)); + + if (buffer->size() != sizeof(int32_t)) { + *message = nullptr; + return Status::OK(); + } + + int32_t message_length = *reinterpret_cast<const int32_t*>(buffer->data()); + + if (message_length == 0) { + // Optional 0 EOS control message + *message = nullptr; + return Status::OK(); + } + + RETURN_NOT_OK(file->Read(message_length, &buffer)); + if (buffer->size() != message_length) { + return Status::IOError("Unexpected end of stream trying to read message"); + } + + return Message::Open(buffer, 0, message); +} + +Status WriteMessage( + const Buffer& message, io::OutputStream* file, int32_t* message_length) { + // Need to write 4 bytes (message size), the message, plus padding to + // end on an 8-byte offset + int64_t start_offset; + RETURN_NOT_OK(file->Tell(&start_offset)); + + int32_t padded_message_length = static_cast<int32_t>(message.size()) + 4; + const int32_t remainder = + (padded_message_length + static_cast<int32_t>(start_offset)) % 8; + if (remainder != 0) { padded_message_length += 8 - remainder; } + + // The returned message size includes the length prefix, the flatbuffer, + // plus padding + *message_length = padded_message_length; + + // Write the flatbuffer size prefix including padding + int32_t flatbuffer_size = padded_message_length - 4; + RETURN_NOT_OK( + file->Write(reinterpret_cast<const uint8_t*>(&flatbuffer_size), sizeof(int32_t))); + + // Write the flatbuffer + RETURN_NOT_OK(file->Write(message.data(), message.size())); + + // Write any padding + int32_t padding = padded_message_length - static_cast<int32_t>(message.size()) - 4; + if (padding > 0) { RETURN_NOT_OK(file->Write(kPaddingBytes, padding)); } + + return Status::OK(); +} + } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/metadata.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/metadata.h 
b/cpp/src/arrow/ipc/metadata.h index 798abdc..fac4a70 100644 --- a/cpp/src/arrow/ipc/metadata.h +++ b/cpp/src/arrow/ipc/metadata.h @@ -22,6 +22,7 @@ #include <cstdint> #include <memory> +#include <string> #include <unordered_map> #include <vector> @@ -37,9 +38,11 @@ struct DataType; struct Field; class Schema; class Status; +class Tensor; namespace io { +class InputStream; class OutputStream; class RandomAccessFile; @@ -53,7 +56,7 @@ struct MetadataVersion { static constexpr const char* kArrowMagicBytes = "ARROW1"; -struct ARROW_EXPORT FileBlock { +struct FileBlock { FileBlock() {} FileBlock(int64_t offset, int32_t metadata_length, int64_t body_length) : offset(offset), metadata_length(metadata_length), body_length(body_length) {} @@ -104,44 +107,25 @@ class DictionaryMemo { class Message; -// Container for serialized Schema metadata contained in an IPC message -class ARROW_EXPORT SchemaMetadata { - public: - explicit SchemaMetadata(const void* header); - explicit SchemaMetadata(const std::shared_ptr<Message>& message); - SchemaMetadata(const std::shared_ptr<Buffer>& message, int64_t offset); - - ~SchemaMetadata(); - - int num_fields() const; - - // Retrieve a list of all the dictionary ids and types required by the schema for - // reconstruction. The presumption is that these will be loaded either from - // the stream or file (or they may already be somewhere else in memory) - Status GetDictionaryTypes(DictionaryTypeMap* id_to_field) const; +// Retrieve a list of all the dictionary ids and types required by the schema for +// reconstruction. The presumption is that these will be loaded either from +// the stream or file (or they may already be somewhere else in memory) +Status GetDictionaryTypes(const void* opaque_schema, DictionaryTypeMap* id_to_field); - // Construct a complete Schema from the message. May be expensive for very - // large schemas if you are only interested in a few fields - Status GetSchema( - const DictionaryMemo& dictionary_memo, std::shared_ptr<Schema>* out) const; - - private: - class SchemaMetadataImpl; - std::unique_ptr<SchemaMetadataImpl> impl_; - - DISALLOW_COPY_AND_ASSIGN(SchemaMetadata); -}; +// Construct a complete Schema from the message. May be expensive for very +// large schemas if you are only interested in a few fields +Status GetSchema(const void* opaque_schema, const DictionaryMemo& dictionary_memo, + std::shared_ptr<Schema>* out); -struct ARROW_EXPORT BufferMetadata { - int32_t page; - int64_t offset; - int64_t length; -}; +Status GetTensorMetadata(const void* opaque_tensor, std::shared_ptr<DataType>* type, + std::vector<int64_t>* shape, std::vector<int64_t>* strides, + std::vector<std::string>* dim_names); class ARROW_EXPORT Message { public: + enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH, TENSOR }; + ~Message(); - enum Type { NONE, SCHEMA, DICTIONARY_BATCH, RECORD_BATCH }; static Status Open(const std::shared_ptr<Buffer>& buffer, int64_t offset, std::shared_ptr<Message>* out); @@ -155,9 +139,6 @@ class ARROW_EXPORT Message { private: Message(const std::shared_ptr<Buffer>& buffer, int64_t offset); - friend class DictionaryBatchMetadata; - friend class SchemaMetadata; - // Hide serialization details from user API class MessageImpl; std::unique_ptr<MessageImpl> impl_; @@ -179,6 +160,17 @@ class ARROW_EXPORT Message { Status ReadMessage(int64_t offset, int32_t metadata_length, io::RandomAccessFile* file, std::shared_ptr<Message>* message); +/// Read length-prefixed message with as-yet unknown length. 
Returns nullptr if +/// there are not enough bytes available or the message length is 0 (e.g. EOS +/// in a stream) +Status ReadMessage(io::InputStream* stream, std::shared_ptr<Message>* message); + +/// Write a serialized message with a length-prefix and padding to an 8-byte offset +/// +/// <message_size: int32><message: const void*><padding> +Status WriteMessage( + const Buffer& message, io::OutputStream* file, int32_t* message_length); + // Serialize arrow::Schema as a Flatbuffer // // \param[in] schema a Schema instance @@ -193,6 +185,9 @@ Status WriteRecordBatchMessage(int64_t length, int64_t body_length, const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers, std::shared_ptr<Buffer>* out); +Status WriteTensorMessage( + const Tensor& tensor, int64_t buffer_start_offset, std::shared_ptr<Buffer>* out); + Status WriteDictionaryMessage(int64_t id, int64_t length, int64_t body_length, const std::vector<FieldMetadata>& nodes, const std::vector<BufferMetadata>& buffers, std::shared_ptr<Buffer>* out); http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/reader.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index 28320d9..b47b773 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -33,6 +33,7 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/type.h" +#include "arrow/tensor.h" #include "arrow/util/logging.h" namespace arrow { @@ -186,28 +187,9 @@ class StreamReader::StreamReaderImpl { } Status ReadNextMessage(Message::Type expected_type, std::shared_ptr<Message>* message) { - std::shared_ptr<Buffer> buffer; - RETURN_NOT_OK(stream_->Read(sizeof(int32_t), &buffer)); - - if (buffer->size() != sizeof(int32_t)) { - *message = nullptr; - return Status::OK(); - } - - int32_t message_length = *reinterpret_cast<const int32_t*>(buffer->data()); - - if (message_length == 0) { - // Optional 0 EOS control message - *message = nullptr; - return Status::OK(); - } - - RETURN_NOT_OK(stream_->Read(message_length, &buffer)); - if (buffer->size() != message_length) { - return Status::IOError("Unexpected end of stream trying to read message"); - } + RETURN_NOT_OK(ReadMessage(stream_.get(), message)); - RETURN_NOT_OK(Message::Open(buffer, 0, message)); + if ((*message) == nullptr) { return Status::OK(); } if ((*message)->type() != expected_type) { std::stringstream ss; @@ -245,8 +227,7 @@ class StreamReader::StreamReaderImpl { std::shared_ptr<Message> message; RETURN_NOT_OK(ReadNextMessage(Message::SCHEMA, &message)); - SchemaMetadata schema_meta(message); - RETURN_NOT_OK(schema_meta.GetDictionaryTypes(&dictionary_types_)); + RETURN_NOT_OK(GetDictionaryTypes(message->header(), &dictionary_types_)); // TODO(wesm): In future, we may want to reconcile the ids in the stream with // those found in the schema @@ -255,7 +236,7 @@ class StreamReader::StreamReaderImpl { RETURN_NOT_OK(ReadNextDictionary()); } - return schema_meta.GetSchema(dictionary_memo_, &schema_); + return GetSchema(message->header(), dictionary_memo_, &schema_); } Status GetNextRecordBatch(std::shared_ptr<RecordBatch>* batch) { @@ -343,7 +324,6 @@ class FileReader::FileReaderImpl { // TODO(wesm): Verify the footer footer_ = flatbuf::GetFooter(footer_buffer_->data()); - schema_metadata_.reset(new SchemaMetadata(footer_->schema())); return Status::OK(); } @@ -372,8 +352,6 @@ class FileReader::FileReaderImpl { return 
FileBlockFromFlatbuffer(footer_->dictionaries()->Get(i)); } - const SchemaMetadata& schema_metadata() const { return *schema_metadata_; } - Status GetRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) { DCHECK_GE(i, 0); DCHECK_LT(i, num_record_batches()); @@ -393,7 +371,7 @@ class FileReader::FileReaderImpl { } Status ReadSchema() { - RETURN_NOT_OK(schema_metadata_->GetDictionaryTypes(&dictionary_fields_)); + RETURN_NOT_OK(GetDictionaryTypes(footer_->schema(), &dictionary_fields_)); // Read all the dictionaries for (int i = 0; i < num_dictionaries(); ++i) { @@ -419,7 +397,7 @@ class FileReader::FileReaderImpl { } // Get the schema - return schema_metadata_->GetSchema(*dictionary_memo_, &schema_); + return GetSchema(footer_->schema(), *dictionary_memo_, &schema_); } Status Open(const std::shared_ptr<io::RandomAccessFile>& file, int64_t footer_offset) { @@ -441,7 +419,6 @@ class FileReader::FileReaderImpl { // Footer metadata std::shared_ptr<Buffer> footer_buffer_; const flatbuf::Footer* footer_; - std::unique_ptr<SchemaMetadata> schema_metadata_; DictionaryTypeMap dictionary_fields_; std::shared_ptr<DictionaryMemo> dictionary_memo_; @@ -485,26 +462,46 @@ Status FileReader::GetRecordBatch(int i, std::shared_ptr<RecordBatch>* batch) { return impl_->GetRecordBatch(i, batch); } -Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset, - io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) { +static Status ReadContiguousPayload(int64_t offset, io::RandomAccessFile* file, + std::shared_ptr<Message>* message, std::shared_ptr<Buffer>* payload) { std::shared_ptr<Buffer> buffer; RETURN_NOT_OK(file->Seek(offset)); + RETURN_NOT_OK(ReadMessage(file, message)); - RETURN_NOT_OK(file->Read(sizeof(int32_t), &buffer)); - int32_t flatbuffer_size = *reinterpret_cast<const int32_t*>(buffer->data()); - - std::shared_ptr<Message> message; - RETURN_NOT_OK(file->Read(flatbuffer_size, &buffer)); - RETURN_NOT_OK(Message::Open(buffer, 0, &message)); + if (*message == nullptr) { + return Status::Invalid("Unable to read metadata at offset"); + } // TODO(ARROW-388): The buffer offsets start at 0, so we must construct a // RandomAccessFile according to that frame of reference - std::shared_ptr<Buffer> buffer_payload; - RETURN_NOT_OK(file->Read(message->body_length(), &buffer_payload)); - io::BufferReader buffer_reader(buffer_payload); + RETURN_NOT_OK(file->Read((*message)->body_length(), payload)); + return Status::OK(); +} +Status ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset, + io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out) { + std::shared_ptr<Buffer> payload; + std::shared_ptr<Message> message; + + RETURN_NOT_OK(ReadContiguousPayload(offset, file, &message, &payload)); + io::BufferReader buffer_reader(payload); return ReadRecordBatch(*message, schema, kMaxNestingDepth, &buffer_reader, out); } +Status ReadTensor( + int64_t offset, io::RandomAccessFile* file, std::shared_ptr<Tensor>* out) { + std::shared_ptr<Message> message; + std::shared_ptr<Buffer> data; + RETURN_NOT_OK(ReadContiguousPayload(offset, file, &message, &data)); + + std::shared_ptr<DataType> type; + std::vector<int64_t> shape; + std::vector<int64_t> strides; + std::vector<std::string> dim_names; + RETURN_NOT_OK( + GetTensorMetadata(message->header(), &type, &shape, &strides, &dim_names)); + return MakeTensor(type, data, shape, strides, dim_names, out); +} + } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/reader.h 
---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/reader.h b/cpp/src/arrow/ipc/reader.h index 6d9e6ca..b62f052 100644 --- a/cpp/src/arrow/ipc/reader.h +++ b/cpp/src/arrow/ipc/reader.h @@ -17,8 +17,8 @@ // Implement Arrow file layout for IPC/RPC purposes and short-lived storage -#ifndef ARROW_IPC_FILE_H -#define ARROW_IPC_FILE_H +#ifndef ARROW_IPC_READER_H +#define ARROW_IPC_READER_H #include <cstdint> #include <memory> @@ -33,6 +33,7 @@ class Buffer; class RecordBatch; class Schema; class Status; +class Tensor; namespace io { @@ -43,18 +44,6 @@ class RandomAccessFile; namespace ipc { -// Generic read functionsh; does not copy data if the input supports zero copy reads - -Status ReadRecordBatch(const Message& metadata, const std::shared_ptr<Schema>& schema, - io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out); - -Status ReadRecordBatch(const Message& metadata, const std::shared_ptr<Schema>& schema, - int max_recursion_depth, io::RandomAccessFile* file, - std::shared_ptr<RecordBatch>* out); - -Status ReadDictionary(const Message& metadata, const DictionaryTypeMap& dictionary_types, - io::RandomAccessFile* file, std::shared_ptr<Array>* out); - class ARROW_EXPORT StreamReader { public: ~StreamReader(); @@ -118,11 +107,24 @@ class ARROW_EXPORT FileReader { std::unique_ptr<FileReaderImpl> impl_; }; +// Generic read functionsh; does not copy data if the input supports zero copy reads +Status ARROW_EXPORT ReadRecordBatch(const Message& metadata, + const std::shared_ptr<Schema>& schema, io::RandomAccessFile* file, + std::shared_ptr<RecordBatch>* out); + +Status ARROW_EXPORT ReadRecordBatch(const Message& metadata, + const std::shared_ptr<Schema>& schema, int max_recursion_depth, + io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out); + /// Read encapsulated message and RecordBatch Status ARROW_EXPORT ReadRecordBatch(const std::shared_ptr<Schema>& schema, int64_t offset, io::RandomAccessFile* file, std::shared_ptr<RecordBatch>* out); +/// EXPERIMENTAL: Read arrow::Tensor from a contiguous message +Status ARROW_EXPORT ReadTensor( + int64_t offset, io::RandomAccessFile* file, std::shared_ptr<Tensor>* out); + } // namespace ipc } // namespace arrow -#endif // ARROW_IPC_FILE_H +#endif // ARROW_IPC_READER_H http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/writer.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 0a19f69..249ef20 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -34,6 +34,7 @@ #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/table.h" +#include "arrow/tensor.h" #include "arrow/type.h" #include "arrow/util/bit-util.h" #include "arrow/util/logging.h" @@ -143,46 +144,6 @@ class RecordBatchWriter : public ArrayVisitor { num_rows, body_length, field_nodes_, buffer_meta_, out); } - Status WriteMetadata(int64_t num_rows, int64_t body_length, io::OutputStream* dst, - int32_t* metadata_length) { - // Now that we have computed the locations of all of the buffers in shared - // memory, the data header can be converted to a flatbuffer and written out - // - // Note: The memory written here is prefixed by the size of the flatbuffer - // itself as an int32_t. 
- std::shared_ptr<Buffer> metadata_fb; - RETURN_NOT_OK(WriteMetadataMessage(num_rows, body_length, &metadata_fb)); - - // Need to write 4 bytes (metadata size), the metadata, plus padding to - // end on an 8-byte offset - int64_t start_offset; - RETURN_NOT_OK(dst->Tell(&start_offset)); - - int32_t padded_metadata_length = static_cast<int32_t>(metadata_fb->size()) + 4; - const int32_t remainder = - (padded_metadata_length + static_cast<int32_t>(start_offset)) % 8; - if (remainder != 0) { padded_metadata_length += 8 - remainder; } - - // The returned metadata size includes the length prefix, the flatbuffer, - // plus padding - *metadata_length = padded_metadata_length; - - // Write the flatbuffer size prefix including padding - int32_t flatbuffer_size = padded_metadata_length - 4; - RETURN_NOT_OK( - dst->Write(reinterpret_cast<const uint8_t*>(&flatbuffer_size), sizeof(int32_t))); - - // Write the flatbuffer - RETURN_NOT_OK(dst->Write(metadata_fb->data(), metadata_fb->size())); - - // Write any padding - int32_t padding = - padded_metadata_length - static_cast<int32_t>(metadata_fb->size()) - 4; - if (padding > 0) { RETURN_NOT_OK(dst->Write(kPaddingBytes, padding)); } - - return Status::OK(); - } - Status Write(const RecordBatch& batch, io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length) { RETURN_NOT_OK(Assemble(batch, body_length)); @@ -192,7 +153,14 @@ class RecordBatchWriter : public ArrayVisitor { RETURN_NOT_OK(dst->Tell(&start_position)); #endif - RETURN_NOT_OK(WriteMetadata(batch.num_rows(), *body_length, dst, metadata_length)); + // Now that we have computed the locations of all of the buffers in shared + // memory, the data header can be converted to a flatbuffer and written out + // + // Note: The memory written here is prefixed by the size of the flatbuffer + // itself as an int32_t. 
+ std::shared_ptr<Buffer> metadata_fb; + RETURN_NOT_OK(WriteMetadataMessage(batch.num_rows(), *body_length, &metadata_fb)); + RETURN_NOT_OK(WriteMessage(*metadata_fb, dst, metadata_length)); #ifndef NDEBUG RETURN_NOT_OK(dst->Tell(¤t_position)); @@ -504,6 +472,28 @@ Status WriteRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, return writer.Write(batch, dst, metadata_length, body_length); } +Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, + io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, + MemoryPool* pool) { + return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length, + pool, kMaxNestingDepth, true); +} + +Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length) { + std::shared_ptr<Buffer> metadata; + RETURN_NOT_OK(WriteTensorMessage(tensor, 0, &metadata)); + RETURN_NOT_OK(WriteMessage(*metadata, dst, metadata_length)); + auto data = tensor.data(); + if (data) { + *body_length = data->size(); + return dst->Write(data->data(), *body_length); + } else { + *body_length = 0; + return Status::OK(); + } +} + Status WriteDictionary(int64_t dictionary_id, const std::shared_ptr<Array>& dictionary, int64_t buffer_start_offset, io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, MemoryPool* pool) { @@ -736,12 +726,5 @@ Status FileWriter::Close() { return impl_->Close(); } -Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offset, - io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, - MemoryPool* pool) { - return WriteRecordBatch(batch, buffer_start_offset, dst, metadata_length, body_length, - pool, kMaxNestingDepth, true); -} - } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/ipc/writer.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/writer.h b/cpp/src/arrow/ipc/writer.h index c572157..8b2dc9c 100644 --- a/cpp/src/arrow/ipc/writer.h +++ b/cpp/src/arrow/ipc/writer.h @@ -17,8 +17,8 @@ // Implement Arrow streaming binary format -#ifndef ARROW_IPC_STREAM_H -#define ARROW_IPC_STREAM_H +#ifndef ARROW_IPC_WRITER_H +#define ARROW_IPC_WRITER_H #include <cstdint> #include <memory> @@ -36,6 +36,7 @@ class MemoryPool; class RecordBatch; class Schema; class Status; +class Tensor; namespace io { @@ -125,7 +126,12 @@ Status WriteLargeRecordBatch(const RecordBatch& batch, int64_t buffer_start_offs io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length, MemoryPool* pool); +/// EXPERIMENTAL: Write arrow::Tensor as a contiguous message +/// <metadata size><metadata><tensor data> +Status WriteTensor(const Tensor& tensor, io::OutputStream* dst, int32_t* metadata_length, + int64_t* body_length); + } // namespace ipc } // namespace arrow -#endif // ARROW_IPC_STREAM_H +#endif // ARROW_IPC_WRITER_H http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/tensor-test.cc b/cpp/src/arrow/tensor-test.cc index 99a9493..336905c 100644 --- a/cpp/src/arrow/tensor-test.cc +++ b/cpp/src/arrow/tensor-test.cc @@ -61,13 +61,36 @@ TEST(TestTensor, BasicCtors) { ASSERT_EQ(24, t1.size()); ASSERT_TRUE(t1.is_mutable()); - ASSERT_FALSE(t1.has_dim_names()); ASSERT_EQ(strides, t1.strides()); ASSERT_EQ(strides, t2.strides()); ASSERT_EQ("foo", t3.dim_name(0)); 
ASSERT_EQ("bar", t3.dim_name(1)); + ASSERT_EQ("", t1.dim_name(0)); + ASSERT_EQ("", t1.dim_name(1)); +} + +TEST(TestTensor, IsContiguous) { + const int64_t values = 24; + std::vector<int64_t> shape = {4, 6}; + std::vector<int64_t> strides = {48, 8}; + + using T = int64_t; + + std::shared_ptr<MutableBuffer> buffer; + ASSERT_OK(AllocateBuffer(default_memory_pool(), values * sizeof(T), &buffer)); + + std::vector<int64_t> c_strides = {48, 8}; + std::vector<int64_t> f_strides = {8, 32}; + std::vector<int64_t> noncontig_strides = {8, 8}; + Int64Tensor t1(buffer, shape, c_strides); + Int64Tensor t2(buffer, shape, f_strides); + Int64Tensor t3(buffer, shape, noncontig_strides); + + ASSERT_TRUE(t1.is_contiguous()); + ASSERT_TRUE(t2.is_contiguous()); + ASSERT_FALSE(t3.is_contiguous()); } } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/tensor.cc b/cpp/src/arrow/tensor.cc index 7c4593f..9a8de51 100644 --- a/cpp/src/arrow/tensor.cc +++ b/cpp/src/arrow/tensor.cc @@ -27,14 +27,15 @@ #include "arrow/array.h" #include "arrow/buffer.h" +#include "arrow/compare.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/logging.h" namespace arrow { -void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_t>& shape, - std::vector<int64_t>* strides) { +static void ComputeRowMajorStrides(const FixedWidthType& type, + const std::vector<int64_t>& shape, std::vector<int64_t>* strides) { int64_t remaining = type.bit_width() / 8; for (int64_t dimsize : shape) { remaining *= dimsize; @@ -46,6 +47,15 @@ void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_ } } +static void ComputeColumnMajorStrides(const FixedWidthType& type, + const std::vector<int64_t>& shape, std::vector<int64_t>* strides) { + int64_t total = type.bit_width() / 8; + for (int64_t dimsize : shape) { + strides->push_back(total); + total *= dimsize; + } +} + /// Constructor with strides and dimension names Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape, const std::vector<int64_t>& strides, @@ -66,14 +76,36 @@ Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buff : Tensor(type, data, shape, {}, {}) {} const std::string& Tensor::dim_name(int i) const { - DCHECK_LT(i, static_cast<int>(dim_names_.size())); - return dim_names_[i]; + static const std::string kEmpty = ""; + if (dim_names_.size() == 0) { + return kEmpty; + } else { + DCHECK_LT(i, static_cast<int>(dim_names_.size())); + return dim_names_[i]; + } } int64_t Tensor::size() const { return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64_t>()); } +bool Tensor::is_contiguous() const { + std::vector<int64_t> c_strides; + std::vector<int64_t> f_strides; + + const auto& fw_type = static_cast<const FixedWidthType&>(*type_); + ComputeRowMajorStrides(fw_type, shape_, &c_strides); + ComputeColumnMajorStrides(fw_type, shape_, &f_strides); + return strides_ == c_strides || strides_ == f_strides; +} + +bool Tensor::Equals(const Tensor& other) const { + bool are_equal = false; + Status error = TensorEquals(*this, other, &are_equal); + if (!error.ok()) { DCHECK(false) << "Tensors not comparable: " << error.ToString(); } + return are_equal; +} + template <typename T> NumericTensor<T>::NumericTensor(const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape, 
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/tensor.cc b/cpp/src/arrow/tensor.cc
index 7c4593f..9a8de51 100644
--- a/cpp/src/arrow/tensor.cc
+++ b/cpp/src/arrow/tensor.cc
@@ -27,14 +27,15 @@
 
 #include "arrow/array.h"
 #include "arrow/buffer.h"
+#include "arrow/compare.h"
 #include "arrow/type.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/logging.h"
 
 namespace arrow {
 
-void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_t>& shape,
-    std::vector<int64_t>* strides) {
+static void ComputeRowMajorStrides(const FixedWidthType& type,
+    const std::vector<int64_t>& shape, std::vector<int64_t>* strides) {
   int64_t remaining = type.bit_width() / 8;
   for (int64_t dimsize : shape) {
     remaining *= dimsize;
@@ -46,6 +47,15 @@ void ComputeRowMajorStrides(const FixedWidthType& type, const std::vector<int64_
   }
 }
 
+static void ComputeColumnMajorStrides(const FixedWidthType& type,
+    const std::vector<int64_t>& shape, std::vector<int64_t>* strides) {
+  int64_t total = type.bit_width() / 8;
+  for (int64_t dimsize : shape) {
+    strides->push_back(total);
+    total *= dimsize;
+  }
+}
+
 /// Constructor with strides and dimension names
 Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
     const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
@@ -66,14 +76,36 @@ Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buff
     : Tensor(type, data, shape, {}, {}) {}
 
 const std::string& Tensor::dim_name(int i) const {
-  DCHECK_LT(i, static_cast<int>(dim_names_.size()));
-  return dim_names_[i];
+  static const std::string kEmpty = "";
+  if (dim_names_.size() == 0) {
+    return kEmpty;
+  } else {
+    DCHECK_LT(i, static_cast<int>(dim_names_.size()));
+    return dim_names_[i];
+  }
 }
 
 int64_t Tensor::size() const {
   return std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies<int64_t>());
 }
 
+bool Tensor::is_contiguous() const {
+  std::vector<int64_t> c_strides;
+  std::vector<int64_t> f_strides;
+
+  const auto& fw_type = static_cast<const FixedWidthType&>(*type_);
+  ComputeRowMajorStrides(fw_type, shape_, &c_strides);
+  ComputeColumnMajorStrides(fw_type, shape_, &f_strides);
+  return strides_ == c_strides || strides_ == f_strides;
+}
+
+bool Tensor::Equals(const Tensor& other) const {
+  bool are_equal = false;
+  Status error = TensorEquals(*this, other, &are_equal);
+  if (!error.ok()) { DCHECK(false) << "Tensors not comparable: " << error.ToString(); }
+  return are_equal;
+}
+
 template <typename T>
 NumericTensor<T>::NumericTensor(const std::shared_ptr<Buffer>& data,
     const std::vector<int64_t>& shape,
     const std::vector<int64_t>& strides,
@@ -112,4 +144,31 @@ template class ARROW_TEMPLATE_EXPORT NumericTensor<HalfFloatType>;
 template class ARROW_TEMPLATE_EXPORT NumericTensor<FloatType>;
 template class ARROW_TEMPLATE_EXPORT NumericTensor<DoubleType>;
 
+#define TENSOR_CASE(TYPE, TENSOR_TYPE)                                        \
+  case Type::TYPE:                                                            \
+    *tensor = std::make_shared<TENSOR_TYPE>(data, shape, strides, dim_names); \
+    break;
+
+Status ARROW_EXPORT MakeTensor(const std::shared_ptr<DataType>& type,
+    const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& strides, const std::vector<std::string>& dim_names,
+    std::shared_ptr<Tensor>* tensor) {
+  switch (type->type) {
+    TENSOR_CASE(INT8, Int8Tensor);
+    TENSOR_CASE(INT16, Int16Tensor);
+    TENSOR_CASE(INT32, Int32Tensor);
+    TENSOR_CASE(INT64, Int64Tensor);
+    TENSOR_CASE(UINT8, UInt8Tensor);
+    TENSOR_CASE(UINT16, UInt16Tensor);
+    TENSOR_CASE(UINT32, UInt32Tensor);
+    TENSOR_CASE(UINT64, UInt64Tensor);
+    TENSOR_CASE(HALF_FLOAT, HalfFloatTensor);
+    TENSOR_CASE(FLOAT, FloatTensor);
+    TENSOR_CASE(DOUBLE, DoubleTensor);
+    default:
+      return Status::NotImplemented(type->ToString());
+  }
+  return Status::OK();
+}
+
 }  // namespace arrow

http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/tensor.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/tensor.h b/cpp/src/arrow/tensor.h
index 7bee867..eeb5c3e 100644
--- a/cpp/src/arrow/tensor.h
+++ b/cpp/src/arrow/tensor.h
@@ -73,12 +73,15 @@ class ARROW_EXPORT Tensor {
       const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
       const std::vector<std::string>& dim_names);
 
+  std::shared_ptr<DataType> type() const { return type_; }
   std::shared_ptr<Buffer> data() const { return data_; }
+
   const std::vector<int64_t>& shape() const { return shape_; }
   const std::vector<int64_t>& strides() const { return strides_; }
+
   int ndim() const { return static_cast<int>(shape_.size()); }
+
   const std::string& dim_name(int i) const;
-  bool has_dim_names() const { return shape_.size() > 0 && dim_names_.size() > 0; }
 
   /// Total number of value cells in the tensor
   int64_t size() const;
@@ -86,13 +89,17 @@ class ARROW_EXPORT Tensor {
   /// Return true if the underlying data buffer is mutable
   bool is_mutable() const { return data_->is_mutable(); }
 
+  bool is_contiguous() const;
+
+  Type::type type_enum() const { return type_->type; }
+
+  bool Equals(const Tensor& other) const;
+
 protected:
   Tensor() {}
 
   std::shared_ptr<DataType> type_;
-
   std::shared_ptr<Buffer> data_;
-
   std::vector<int64_t> shape_;
   std::vector<int64_t> strides_;
@@ -126,6 +133,11 @@ class ARROW_EXPORT NumericTensor : public Tensor {
   value_type* mutable_raw_data_;
 };
 
+Status ARROW_EXPORT MakeTensor(const std::shared_ptr<DataType>& type,
+    const std::shared_ptr<Buffer>& data, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& strides, const std::vector<std::string>& dim_names,
+    std::shared_ptr<Tensor>* tensor);
+
 // ----------------------------------------------------------------------
 // extern templates and other details
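The `MakeTensor` factory declared above dispatches on the type enum and constructs the matching concrete tensor class via the `TENSOR_CASE` macro. A minimal caller could look like this sketch; `MakeInt64Matrix` is a hypothetical helper, and only the `MakeTensor` declaration above plus the allocation helpers already used in tensor-test.cc are assumed.

```cpp
// Sketch of the new MakeTensor factory (hypothetical caller, not part of
// this patch; strides are given explicitly as row-major byte offsets).
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "arrow/buffer.h"
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/tensor.h"
#include "arrow/type.h"

arrow::Status MakeInt64Matrix(std::shared_ptr<arrow::Tensor>* out) {
  std::vector<int64_t> shape = {4, 6};
  std::vector<int64_t> strides = {48, 8};  // row-major byte strides for int64
  std::vector<std::string> dim_names = {"row", "col"};
  std::shared_ptr<arrow::MutableBuffer> buffer;
  RETURN_NOT_OK(arrow::AllocateBuffer(
      arrow::default_memory_pool(), 4 * 6 * sizeof(int64_t), &buffer));
  // Dispatches on type->type and constructs an Int64Tensor internally.
  return arrow::MakeTensor(arrow::int64(), buffer, shape, strides, dim_names, out);
}
```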
http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/cpp/src/arrow/type_traits.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 1270aee..b73d5a6 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -38,6 +38,7 @@ template <>
 struct TypeTraits<UInt8Type> {
   using ArrayType = UInt8Array;
   using BuilderType = UInt8Builder;
+  using TensorType = UInt8Tensor;
   static inline int64_t bytes_required(int64_t elements) { return elements; }
   constexpr static bool is_parameter_free = true;
   static inline std::shared_ptr<DataType> type_singleton() { return uint8(); }
@@ -47,6 +48,7 @@ template <>
 struct TypeTraits<Int8Type> {
   using ArrayType = Int8Array;
   using BuilderType = Int8Builder;
+  using TensorType = Int8Tensor;
   static inline int64_t bytes_required(int64_t elements) { return elements; }
   constexpr static bool is_parameter_free = true;
   static inline std::shared_ptr<DataType> type_singleton() { return int8(); }
@@ -56,6 +58,7 @@ template <>
 struct TypeTraits<UInt16Type> {
   using ArrayType = UInt16Array;
   using BuilderType = UInt16Builder;
+  using TensorType = UInt16Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(uint16_t);
@@ -68,6 +71,7 @@ template <>
 struct TypeTraits<Int16Type> {
   using ArrayType = Int16Array;
   using BuilderType = Int16Builder;
+  using TensorType = Int16Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(int16_t);
@@ -80,6 +84,7 @@ template <>
 struct TypeTraits<UInt32Type> {
   using ArrayType = UInt32Array;
   using BuilderType = UInt32Builder;
+  using TensorType = UInt32Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(uint32_t);
@@ -92,6 +97,7 @@ template <>
 struct TypeTraits<Int32Type> {
   using ArrayType = Int32Array;
   using BuilderType = Int32Builder;
+  using TensorType = Int32Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(int32_t);
@@ -104,6 +110,7 @@ template <>
 struct TypeTraits<UInt64Type> {
   using ArrayType = UInt64Array;
   using BuilderType = UInt64Builder;
+  using TensorType = UInt64Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(uint64_t);
@@ -116,6 +123,7 @@ template <>
 struct TypeTraits<Int64Type> {
   using ArrayType = Int64Array;
   using BuilderType = Int64Builder;
+  using TensorType = Int64Tensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(int64_t);
@@ -185,6 +193,7 @@ template <>
 struct TypeTraits<HalfFloatType> {
   using ArrayType = HalfFloatArray;
   using BuilderType = HalfFloatBuilder;
+  using TensorType = HalfFloatTensor;
   static inline int64_t bytes_required(int64_t elements) {
     return elements * sizeof(uint16_t);
@@ -197,6 +206,7 @@ template <>
 struct TypeTraits<FloatType> {
   using ArrayType = FloatArray;
   using BuilderType = FloatBuilder;
+  using TensorType = FloatTensor;
   static inline int64_t bytes_required(int64_t elements) {
     return static_cast<int64_t>(elements * sizeof(float));
@@ -209,6 +219,7 @@ template <>
 struct TypeTraits<DoubleType> {
   using ArrayType = DoubleArray;
   using BuilderType = DoubleBuilder;
+  using TensorType = DoubleTensor;
   static inline int64_t bytes_required(int64_t elements) {
     return static_cast<int64_t>(elements * sizeof(double));
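The new `TensorType` alias lets templated code name the concrete tensor class for a given Arrow type parameter, in the same way `ArrayType` and `BuilderType` are used elsewhere in the codebase. A hedged sketch of what this enables; `MakeTypedTensor` is a hypothetical helper, not part of the patch.

```cpp
// Hypothetical helper showing what the TensorType alias enables: generic
// construction of the concrete tensor class for an Arrow type parameter.
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "arrow/buffer.h"
#include "arrow/tensor.h"
#include "arrow/type_traits.h"

template <typename ArrowType>
std::shared_ptr<typename arrow::TypeTraits<ArrowType>::TensorType> MakeTypedTensor(
    const std::shared_ptr<arrow::Buffer>& data, const std::vector<int64_t>& shape,
    const std::vector<int64_t>& strides) {
  using TensorType = typename arrow::TypeTraits<ArrowType>::TensorType;
  // Same constructor arguments as the TENSOR_CASE macro in tensor.cc above.
  return std::make_shared<TensorType>(data, shape, strides,
      std::vector<std::string>{});
}

// Usage: auto t = MakeTypedTensor<arrow::Int64Type>(buffer, {4, 6}, {48, 8});
```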
implemented"); } +#define TENSOR_VISIT_INLINE(TYPE_CLASS) \ + case TYPE_CLASS::type_id: \ + return visitor->Visit( \ + static_cast<const typename TypeTraits<TYPE_CLASS>::TensorType&>(array)); + +template <typename VISITOR> +inline Status VisitTensorInline(const Tensor& array, VISITOR* visitor) { + switch (array.type_enum()) { + TENSOR_VISIT_INLINE(Int8Type); + TENSOR_VISIT_INLINE(UInt8Type); + TENSOR_VISIT_INLINE(Int16Type); + TENSOR_VISIT_INLINE(UInt16Type); + TENSOR_VISIT_INLINE(Int32Type); + TENSOR_VISIT_INLINE(UInt32Type); + TENSOR_VISIT_INLINE(Int64Type); + TENSOR_VISIT_INLINE(UInt64Type); + TENSOR_VISIT_INLINE(HalfFloatType); + TENSOR_VISIT_INLINE(FloatType); + TENSOR_VISIT_INLINE(DoubleType); + default: + break; + } + return Status::NotImplemented("Type not implemented"); +} + } // namespace arrow #endif // ARROW_VISITOR_INLINE_H http://git-wip-us.apache.org/repos/asf/arrow/blob/957a0e67/format/Tensor.fbs ---------------------------------------------------------------------- diff --git a/format/Tensor.fbs b/format/Tensor.fbs index bc5b6d1..18b614c 100644 --- a/format/Tensor.fbs +++ b/format/Tensor.fbs @@ -32,16 +32,6 @@ table TensorDim { name: string; } -enum TensorOrder : byte { - /// Higher dimensions vary first when traversing data in byte-contiguous - /// order, aka "C order" - ROW_MAJOR, - - /// Lower dimensions vary first when traversing data in byte-contiguous - /// order, aka "Fortran order" - COLUMN_MAJOR -} - table Tensor { /// The type of data contained in a value cell. Currently only fixed-width /// value types are supported, no strings or nested types @@ -50,8 +40,8 @@ table Tensor { /// The dimensions of the tensor, optionally named shape: [TensorDim]; - /// The memory order of the tensor's data - order: TensorOrder; + /// Non-negative byte offsets to advance one value cell along each dimension + strides: [long]; /// The location and size of the tensor's data data: Buffer;