Repository: arrow
Updated Branches:
  refs/heads/master 4733ee876 -> 23fe6ae02
ARROW-338: Implement visitor pattern for IPC loading/unloading This is a first cut at getting rid of the if-then-else statements and using the visitor pattern. This also has the benefit of forcing us to provide implementations should we add new types to Arrow. Author: Wes McKinney <wes.mckin...@twosigma.com> Closes #256 from wesm/ARROW-338 and squashes the following commits: 59bac66 [Wes McKinney] Fix accidental copy 17214c4 [Wes McKinney] Fix comment 6b00da4 [Wes McKinney] Implement visitor pattern for IPC loading/unloading Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/23fe6ae0 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/23fe6ae0 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/23fe6ae0 Branch: refs/heads/master Commit: 23fe6ae02a6fa6ff912986c45079e25b3e5e4deb Parents: 4733ee8 Author: Wes McKinney <wes.mckin...@twosigma.com> Authored: Thu Dec 29 10:22:40 2016 +0100 Committer: Uwe L. Korn <uw...@xhochy.com> Committed: Thu Dec 29 10:22:40 2016 +0100 ---------------------------------------------------------------------- cpp/src/arrow/array.h | 1 + cpp/src/arrow/ipc/adapter.cc | 477 ++++++++++++++++++++++++-------------- cpp/src/arrow/type_fwd.h | 3 +- 3 files changed, 306 insertions(+), 175 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/23fe6ae0/cpp/src/arrow/array.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 5cd56d6..6239ccc 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -471,6 +471,7 @@ extern template class ARROW_EXPORT NumericArray<FloatType>; extern template class ARROW_EXPORT NumericArray<DoubleType>; extern template class ARROW_EXPORT NumericArray<TimestampType>; extern template class ARROW_EXPORT NumericArray<DateType>; +extern template class ARROW_EXPORT 
NumericArray<TimeType>; #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop http://git-wip-us.apache.org/repos/asf/arrow/blob/23fe6ae0/cpp/src/arrow/ipc/adapter.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index f813c1d..ac4054b 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -34,6 +34,7 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/type.h" +#include "arrow/type_fwd.h" #include "arrow/util/bit-util.h" #include "arrow/util/logging.h" @@ -43,80 +44,34 @@ namespace flatbuf = org::apache::arrow::flatbuf; namespace ipc { -static bool IsPrimitive(const DataType* type) { - DCHECK(type != nullptr); - switch (type->type) { - // NA is null type or "no type", considered primitive for now - case Type::NA: - case Type::BOOL: - case Type::UINT8: - case Type::INT8: - case Type::UINT16: - case Type::INT16: - case Type::UINT32: - case Type::INT32: - case Type::UINT64: - case Type::INT64: - case Type::FLOAT: - case Type::DOUBLE: - return true; - default: - return false; - } -} - // ---------------------------------------------------------------------- // Record batch write path -Status VisitArray(const Array* arr, std::vector<flatbuf::FieldNode>* field_nodes, - std::vector<std::shared_ptr<Buffer>>* buffers, int max_recursion_depth) { - if (max_recursion_depth <= 0) { return Status::Invalid("Max recursion depth reached"); } - DCHECK(arr); - DCHECK(field_nodes); - // push back all common elements - field_nodes->push_back(flatbuf::FieldNode(arr->length(), arr->null_count())); - if (arr->null_count() > 0) { - buffers->push_back(arr->null_bitmap()); - } else { - // Push a dummy zero-length buffer, not to be copied - buffers->push_back(std::make_shared<Buffer>(nullptr, 0)); - } - - const DataType* arr_type = arr->type().get(); - if (IsPrimitive(arr_type)) { - const auto prim_arr = static_cast<const 
PrimitiveArray*>(arr); - buffers->push_back(prim_arr->data()); - } else if (arr->type_enum() == Type::STRING || arr->type_enum() == Type::BINARY) { - const auto binary_arr = static_cast<const BinaryArray*>(arr); - buffers->push_back(binary_arr->offsets()); - buffers->push_back(binary_arr->data()); - } else if (arr->type_enum() == Type::LIST) { - const auto list_arr = static_cast<const ListArray*>(arr); - buffers->push_back(list_arr->offsets()); - RETURN_NOT_OK(VisitArray( - list_arr->values().get(), field_nodes, buffers, max_recursion_depth - 1)); - } else if (arr->type_enum() == Type::STRUCT) { - const auto struct_arr = static_cast<const StructArray*>(arr); - for (auto& field : struct_arr->fields()) { - RETURN_NOT_OK( - VisitArray(field.get(), field_nodes, buffers, max_recursion_depth - 1)); - } - } else { - return Status::NotImplemented("Unrecognized type"); - } - return Status::OK(); -} - -class RecordBatchWriter { +class RecordBatchWriter : public ArrayVisitor { public: RecordBatchWriter(const std::vector<std::shared_ptr<Array>>& columns, int32_t num_rows, int64_t buffer_start_offset, int max_recursion_depth) - : columns_(&columns), + : columns_(columns), num_rows_(num_rows), - buffer_start_offset_(buffer_start_offset), - max_recursion_depth_(max_recursion_depth) {} + max_recursion_depth_(max_recursion_depth), + buffer_start_offset_(buffer_start_offset) {} - Status AssemblePayload(int64_t* body_length) { + Status VisitArray(const Array& arr) { + if (max_recursion_depth_ <= 0) { + return Status::Invalid("Max recursion depth reached"); + } + // push back all common elements + field_nodes_.push_back(flatbuf::FieldNode(arr.length(), arr.null_count())); + if (arr.null_count() > 0) { + buffers_.push_back(arr.null_bitmap()); + } else { + // Push a dummy zero-length buffer, not to be copied + buffers_.push_back(std::make_shared<Buffer>(nullptr, 0)); + } + return arr.Accept(this); + } + + Status Assemble(int64_t* body_length) { if (field_nodes_.size() > 0) { 
field_nodes_.clear(); buffer_meta_.clear(); @@ -124,9 +79,8 @@ class RecordBatchWriter { } // Perform depth-first traversal of the row-batch - for (size_t i = 0; i < columns_->size(); ++i) { - const Array* arr = (*columns_)[i].get(); - RETURN_NOT_OK(VisitArray(arr, &field_nodes_, &buffers_, max_recursion_depth_)); + for (size_t i = 0; i < columns_.size(); ++i) { + RETURN_NOT_OK(VisitArray(*columns_[i].get())); } // The position for the start of a buffer relative to the passed frame of @@ -199,7 +153,7 @@ class RecordBatchWriter { } Status Write(io::OutputStream* dst, int32_t* metadata_length, int64_t* body_length) { - RETURN_NOT_OK(AssemblePayload(body_length)); + RETURN_NOT_OK(Assemble(body_length)); #ifndef NDEBUG int64_t start_position, current_position; @@ -249,15 +203,92 @@ class RecordBatchWriter { } private: + Status Visit(const NullArray& array) override { return Status::NotImplemented("null"); } + + Status VisitPrimitive(const PrimitiveArray& array) { + buffers_.push_back(array.data()); + return Status::OK(); + } + + Status VisitBinary(const BinaryArray& array) { + buffers_.push_back(array.offsets()); + buffers_.push_back(array.data()); + return Status::OK(); + } + + Status Visit(const BooleanArray& array) override { return VisitPrimitive(array); } + + Status Visit(const Int8Array& array) override { return VisitPrimitive(array); } + + Status Visit(const Int16Array& array) override { return VisitPrimitive(array); } + + Status Visit(const Int32Array& array) override { return VisitPrimitive(array); } + + Status Visit(const Int64Array& array) override { return VisitPrimitive(array); } + + Status Visit(const UInt8Array& array) override { return VisitPrimitive(array); } + + Status Visit(const UInt16Array& array) override { return VisitPrimitive(array); } + + Status Visit(const UInt32Array& array) override { return VisitPrimitive(array); } + + Status Visit(const UInt64Array& array) override { return VisitPrimitive(array); } + + Status Visit(const HalfFloatArray& 
array) override { return VisitPrimitive(array); } + + Status Visit(const FloatArray& array) override { return VisitPrimitive(array); } + + Status Visit(const DoubleArray& array) override { return VisitPrimitive(array); } + + Status Visit(const StringArray& array) override { return VisitBinary(array); } + + Status Visit(const BinaryArray& array) override { return VisitBinary(array); } + + Status Visit(const DateArray& array) override { return VisitPrimitive(array); } + + Status Visit(const TimeArray& array) override { return VisitPrimitive(array); } + + Status Visit(const TimestampArray& array) override { return VisitPrimitive(array); } + + Status Visit(const IntervalArray& array) override { + return Status::NotImplemented("interval"); + } + + Status Visit(const DecimalArray& array) override { + return Status::NotImplemented("decimal"); + } + + Status Visit(const ListArray& array) override { + buffers_.push_back(array.offsets()); + --max_recursion_depth_; + RETURN_NOT_OK(VisitArray(*array.values().get())); + ++max_recursion_depth_; + return Status::OK(); + } + + Status Visit(const StructArray& array) override { + --max_recursion_depth_; + for (const auto& field : array.fields()) { + RETURN_NOT_OK(VisitArray(*field.get())); + } + ++max_recursion_depth_; + return Status::OK(); + } + + Status Visit(const UnionArray& array) override { + return Status::NotImplemented("union"); + } + // Do not copy this vector. 
Ownership must be retained elsewhere - const std::vector<std::shared_ptr<Array>>* columns_; + const std::vector<std::shared_ptr<Array>>& columns_; int32_t num_rows_; - int64_t buffer_start_offset_; std::vector<flatbuf::FieldNode> field_nodes_; std::vector<flatbuf::Buffer> buffer_meta_; std::vector<std::shared_ptr<Buffer>> buffers_; - int max_recursion_depth_; + + int64_t max_recursion_depth_; + int64_t buffer_start_offset_; }; Status WriteRecordBatch(const std::vector<std::shared_ptr<Array>>& columns, @@ -279,143 +310,241 @@ Status GetRecordBatchSize(const RecordBatch* batch, int64_t* size) { // ---------------------------------------------------------------------- // Record batch read path -class RecordBatchReader { - public: - RecordBatchReader(const std::shared_ptr<RecordBatchMetadata>& metadata, - const std::shared_ptr<Schema>& schema, int max_recursion_depth, - io::ReadableFileInterface* file) - : metadata_(metadata), - schema_(schema), - max_recursion_depth_(max_recursion_depth), - file_(file) { - num_buffers_ = metadata->num_buffers(); - num_flattened_fields_ = metadata->num_fields(); - } +struct RecordBatchContext { + const RecordBatchMetadata* metadata; + int buffer_index; + int field_index; + int max_recursion_depth; +}; - Status Read(std::shared_ptr<RecordBatch>* out) { - std::vector<std::shared_ptr<Array>> arrays(schema_->num_fields()); +// Traverse the flattened record batch metadata and reassemble the +// corresponding array containers +class ArrayLoader : public TypeVisitor { + public: + ArrayLoader( + const Field& field, RecordBatchContext* context, io::ReadableFileInterface* file) + : field_(field), context_(context), file_(file) {} - // The field_index and buffer_index are incremented in NextArray based on - // how much of the batch is "consumed" (through nested data reconstruction, - // for example) - field_index_ = 0; - buffer_index_ = 0; - for (int i = 0; i < schema_->num_fields(); ++i) { - const Field* field = schema_->field(i).get(); - 
RETURN_NOT_OK(NextArray(field, max_recursion_depth_, &arrays[i])); + Status Load(std::shared_ptr<Array>* out) { + if (context_->max_recursion_depth <= 0) { + return Status::Invalid("Max recursion depth reached"); } - *out = std::make_shared<RecordBatch>(schema_, metadata_->length(), arrays); + // Load the array + RETURN_NOT_OK(field_.type->Accept(this)); + + *out = std::move(result_); return Status::OK(); } private: - // Traverse the flattened record batch metadata and reassemble the - // corresponding array containers - Status NextArray( - const Field* field, int max_recursion_depth, std::shared_ptr<Array>* out) { - const TypePtr& type = field->type; - if (max_recursion_depth <= 0) { - return Status::Invalid("Max recursion depth reached"); + const Field& field_; + RecordBatchContext* context_; + io::ReadableFileInterface* file_; + + // Used in visitor pattern + std::shared_ptr<Array> result_; + + Status LoadChild(const Field& field, std::shared_ptr<Array>* out) { + ArrayLoader loader(field, context_, file_); + --context_->max_recursion_depth; + RETURN_NOT_OK(loader.Load(out)); + ++context_->max_recursion_depth; + return Status::OK(); + } + + Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) { + BufferMetadata metadata = context_->metadata->buffer(buffer_index); + + if (metadata.length == 0) { + *out = std::make_shared<Buffer>(nullptr, 0); + return Status::OK(); + } else { + return file_->ReadAt(metadata.offset, metadata.length, out); } + } + Status LoadCommon(FieldMetadata* field_meta, std::shared_ptr<Buffer>* null_bitmap) { // pop off a field - if (field_index_ >= num_flattened_fields_) { + if (context_->field_index >= context_->metadata->num_fields()) { return Status::Invalid("Ran out of field metadata, likely malformed"); } // This only contains the length and null count, which we need to figure // out what to do with the buffers. 
For example, if null_count == 0, then // we can skip that buffer without reading from shared memory - FieldMetadata field_meta = metadata_->field(field_index_++); + *field_meta = context_->metadata->field(context_->field_index++); // extract null_bitmap which is common to all arrays + if (field_meta->null_count == 0) { + *null_bitmap = nullptr; + } else { + RETURN_NOT_OK(GetBuffer(context_->buffer_index, null_bitmap)); + } + context_->buffer_index++; + return Status::OK(); + } + + Status LoadPrimitive(const DataType& type) { + FieldMetadata field_meta; std::shared_ptr<Buffer> null_bitmap; - if (field_meta.null_count == 0) { - ++buffer_index_; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + + std::shared_ptr<Buffer> data; + if (field_meta.length > 0) { + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &data)); } else { - RETURN_NOT_OK(GetBuffer(buffer_index_++, &null_bitmap)); + context_->buffer_index++; + data.reset(new Buffer(nullptr, 0)); } + return MakePrimitiveArray(field_.type, field_meta.length, data, field_meta.null_count, + null_bitmap, &result_); + } - if (IsPrimitive(type.get())) { - std::shared_ptr<Buffer> data; - if (field_meta.length > 0) { - RETURN_NOT_OK(GetBuffer(buffer_index_++, &data)); - } else { - buffer_index_++; - data.reset(new Buffer(nullptr, 0)); - } - return MakePrimitiveArray( - type, field_meta.length, data, field_meta.null_count, null_bitmap, out); - } else if (type->type == Type::STRING || type->type == Type::BINARY) { - std::shared_ptr<Buffer> offsets; - std::shared_ptr<Buffer> values; - RETURN_NOT_OK(GetBuffer(buffer_index_++, &offsets)); - RETURN_NOT_OK(GetBuffer(buffer_index_++, &values)); - - if (type->type == Type::STRING) { - *out = std::make_shared<StringArray>( - field_meta.length, offsets, values, field_meta.null_count, null_bitmap); - } else { - *out = std::make_shared<BinaryArray>( - field_meta.length, offsets, values, field_meta.null_count, null_bitmap); - } - return Status::OK(); - } else if (type->type == 
Type::LIST) { - std::shared_ptr<Buffer> offsets; - RETURN_NOT_OK(GetBuffer(buffer_index_++, &offsets)); - const int num_children = type->num_children(); - if (num_children != 1) { - std::stringstream ss; - ss << "Field: " << field->ToString() - << " has wrong number of children:" << num_children; - return Status::Invalid(ss.str()); - } - std::shared_ptr<Array> values_array; - RETURN_NOT_OK( - NextArray(type->child(0).get(), max_recursion_depth - 1, &values_array)); - *out = std::make_shared<ListArray>(type, field_meta.length, offsets, values_array, - field_meta.null_count, null_bitmap); - return Status::OK(); - } else if (type->type == Type::STRUCT) { - const int num_children = type->num_children(); - std::vector<ArrayPtr> fields; - fields.reserve(num_children); - for (int child_idx = 0; child_idx < num_children; ++child_idx) { - std::shared_ptr<Array> field_array; - RETURN_NOT_OK(NextArray( - type->child(child_idx).get(), max_recursion_depth - 1, &field_array)); - fields.push_back(field_array); - } - out->reset(new StructArray( - type, field_meta.length, fields, field_meta.null_count, null_bitmap)); - return Status::OK(); + template <typename CONTAINER> + Status LoadBinary() { + FieldMetadata field_meta; + std::shared_ptr<Buffer> null_bitmap; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + + std::shared_ptr<Buffer> offsets; + std::shared_ptr<Buffer> values; + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &values)); + + result_ = std::make_shared<CONTAINER>( + field_meta.length, offsets, values, field_meta.null_count, null_bitmap); + return Status::OK(); + } + + Status Visit(const NullType& type) override { return Status::NotImplemented("null"); } + + Status Visit(const BooleanType& type) override { return LoadPrimitive(type); } + + Status Visit(const Int8Type& type) override { return LoadPrimitive(type); } + + Status Visit(const Int16Type& type) override { return LoadPrimitive(type); } + 
+ Status Visit(const Int32Type& type) override { return LoadPrimitive(type); } + + Status Visit(const Int64Type& type) override { return LoadPrimitive(type); } + + Status Visit(const UInt8Type& type) override { return LoadPrimitive(type); } + + Status Visit(const UInt16Type& type) override { return LoadPrimitive(type); } + + Status Visit(const UInt32Type& type) override { return LoadPrimitive(type); } + + Status Visit(const UInt64Type& type) override { return LoadPrimitive(type); } + + Status Visit(const HalfFloatType& type) override { return LoadPrimitive(type); } + + Status Visit(const FloatType& type) override { return LoadPrimitive(type); } + + Status Visit(const DoubleType& type) override { return LoadPrimitive(type); } + + Status Visit(const StringType& type) override { return LoadBinary<StringArray>(); } + + Status Visit(const BinaryType& type) override { return LoadBinary<BinaryArray>(); } + + Status Visit(const DateType& type) override { return LoadPrimitive(type); } + + Status Visit(const TimeType& type) override { return LoadPrimitive(type); } + + Status Visit(const TimestampType& type) override { return LoadPrimitive(type); } + + Status Visit(const IntervalType& type) override { + return Status::NotImplemented(type.ToString()); + } + + Status Visit(const DecimalType& type) override { + return Status::NotImplemented(type.ToString()); + } + + Status Visit(const ListType& type) override { + FieldMetadata field_meta; + std::shared_ptr<Buffer> null_bitmap; + std::shared_ptr<Buffer> offsets; + + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); + RETURN_NOT_OK(GetBuffer(context_->buffer_index++, &offsets)); + + const int num_children = type.num_children(); + if (num_children != 1) { + std::stringstream ss; + ss << "Wrong number of children: " << num_children; + return Status::Invalid(ss.str()); } + std::shared_ptr<Array> values_array; - return Status::NotImplemented("Non-primitive types not complete yet"); + RETURN_NOT_OK(LoadChild(*type.child(0).get(), 
&values_array)); + + result_ = std::make_shared<ListArray>(field_.type, field_meta.length, offsets, + values_array, field_meta.null_count, null_bitmap); + return Status::OK(); } - Status GetBuffer(int buffer_index, std::shared_ptr<Buffer>* out) { - BufferMetadata metadata = metadata_->buffer(buffer_index); + Status Visit(const StructType& type) override { + FieldMetadata field_meta; + std::shared_ptr<Buffer> null_bitmap; + RETURN_NOT_OK(LoadCommon(&field_meta, &null_bitmap)); - if (metadata.length == 0) { - *out = std::make_shared<Buffer>(nullptr, 0); - return Status::OK(); - } else { - return file_->ReadAt(metadata.offset, metadata.length, out); + const int num_children = type.num_children(); + std::vector<ArrayPtr> fields; + fields.reserve(num_children); + + for (int child_idx = 0; child_idx < num_children; ++child_idx) { + std::shared_ptr<Array> field_array; + RETURN_NOT_OK(LoadChild(*type.child(child_idx).get(), &field_array)); + fields.emplace_back(field_array); + } + + result_ = std::make_shared<StructArray>( + field_.type, field_meta.length, fields, field_meta.null_count, null_bitmap); + return Status::OK(); + } + + Status Visit(const UnionType& type) override { + return Status::NotImplemented(type.ToString()); + } +}; + +class RecordBatchReader { + public: + RecordBatchReader(const std::shared_ptr<RecordBatchMetadata>& metadata, + const std::shared_ptr<Schema>& schema, int max_recursion_depth, + io::ReadableFileInterface* file) + : metadata_(metadata), + schema_(schema), + max_recursion_depth_(max_recursion_depth), + file_(file) {} + + Status Read(std::shared_ptr<RecordBatch>* out) { + std::vector<std::shared_ptr<Array>> arrays(schema_->num_fields()); + + // The field_index and buffer_index are incremented in the ArrayLoader + // based on how much of the batch is "consumed" (through nested data + // reconstruction, for example) + context_.metadata = metadata_.get(); + context_.field_index = 0; + context_.buffer_index = 0; + context_.max_recursion_depth = 
max_recursion_depth_; + + for (int i = 0; i < schema_->num_fields(); ++i) { + ArrayLoader loader(*schema_->field(i).get(), &context_, file_); + RETURN_NOT_OK(loader.Load(&arrays[i])); } + + *out = std::make_shared<RecordBatch>(schema_, metadata_->length(), arrays); + return Status::OK(); } private: + RecordBatchContext context_; std::shared_ptr<RecordBatchMetadata> metadata_; std::shared_ptr<Schema> schema_; int max_recursion_depth_; io::ReadableFileInterface* file_; - - int field_index_; - int buffer_index_; - int num_buffers_; - int num_flattened_fields_; }; Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length, http://git-wip-us.apache.org/repos/asf/arrow/blob/23fe6ae0/cpp/src/arrow/type_fwd.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index a9db32d..a14c535 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -91,7 +91,8 @@ using DateArray = NumericArray<DateType>; using DateBuilder = NumericBuilder<DateType>; struct TimeType; -class TimeArray; +using TimeArray = NumericArray<TimeType>; +using TimeBuilder = NumericBuilder<TimeType>; struct TimestampType; using TimestampArray = NumericArray<TimestampType>;