Repository: arrow Updated Branches: refs/heads/master 2a2b1094e -> 25ba44c32
ARROW-460: [C++] JSON read/write for dictionaries cc @BryanCutler; this ended up being rather tedious. As one uncertainty: I wasn't sure what to write for the `typeLayout` field for dictionary fields. We don't use this at all in C++, so let me know what you need in Java (if anything) Author: Wes McKinney <wes.mckin...@twosigma.com> Closes #750 from wesm/ARROW-460 and squashes the following commits: 36d6019e [Wes McKinney] Put dictionaries in top level of json object 92e2a95f [Wes McKinney] Some debugging help, get test suite passing 15b8fc6c [Wes McKinney] Schema read correct, but array reads incorrect 6581bd6d [Wes McKinney] Cleaning up JSON ArrayReader to use inline visitors. Progress 44115663 [Wes McKinney] Misc fixes, dictionary schema roundtrip not complete yet 59580222 [Wes McKinney] Draft JSON roundtrip with dictionaries, not yet tested Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/25ba44c3 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/25ba44c3 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/25ba44c3 Branch: refs/heads/master Commit: 25ba44c3287de968ba22fc21577fe4639d81c4dc Parents: 2a2b109 Author: Wes McKinney <wes.mckin...@twosigma.com> Authored: Tue Jun 13 11:54:18 2017 -0400 Committer: Wes McKinney <wes.mckin...@twosigma.com> Committed: Tue Jun 13 11:54:18 2017 -0400 ---------------------------------------------------------------------- cpp/src/arrow/ipc/ipc-json-test.cc | 23 +- cpp/src/arrow/ipc/json-internal.cc | 723 +++++++++++++++++++++----------- cpp/src/arrow/ipc/json-internal.h | 30 +- cpp/src/arrow/ipc/json.cc | 55 +-- cpp/src/arrow/ipc/json.h | 3 - cpp/src/arrow/ipc/metadata.h | 3 + cpp/src/arrow/ipc/test-common.h | 13 +- cpp/src/arrow/ipc/writer.cc | 13 +- 8 files changed, 535 insertions(+), 328 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/ipc-json-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/ipc-json-test.cc b/cpp/src/arrow/ipc/ipc-json-test.cc index 9cf6a88..e06af9f 100644 --- a/cpp/src/arrow/ipc/ipc-json-test.cc +++ b/cpp/src/arrow/ipc/ipc-json-test.cc @@ -39,18 +39,25 @@ namespace arrow { namespace ipc { +namespace json { void TestSchemaRoundTrip(const Schema& schema) { rj::StringBuffer sb; rj::Writer<rj::StringBuffer> writer(sb); - ASSERT_OK(WriteJsonSchema(schema, &writer)); + writer.StartObject(); + ASSERT_OK(internal::WriteSchema(schema, &writer)); + writer.EndObject(); + + std::string json_schema = sb.GetString(); rj::Document d; - d.Parse(sb.GetString()); + d.Parse(json_schema); std::shared_ptr<Schema> out; - ASSERT_OK(ReadJsonSchema(d, &out)); + if (!internal::ReadSchema(d, default_memory_pool(), &out).ok()) { + FAIL() << "Unable to read JSON schema: " << json_schema; + } if (!schema.Equals(*out)) { FAIL() << "In schema: " << schema.ToString() << "\nOut schema: " << out->ToString(); @@ -63,7 +70,7 @@ void TestArrayRoundTrip(const Array& array) { rj::StringBuffer sb; rj::Writer<rj::StringBuffer> writer(sb); - ASSERT_OK(WriteJsonArray(name, array, &writer)); + ASSERT_OK(internal::WriteArray(name, array, &writer)); std::string array_as_json = sb.GetString(); @@ -73,7 +80,7 @@ void TestArrayRoundTrip(const Array& array) { if (d.HasParseError()) { FAIL() << "JSON parsing failed"; } std::shared_ptr<Array> out; - ASSERT_OK(ReadJsonArray(default_memory_pool(), d, array.type(), &out)); + ASSERT_OK(internal::ReadArray(default_memory_pool(), d, array.type(), &out)); // std::cout << array_as_json << std::endl; CompareArraysDetailed(0, *out, array); @@ -355,7 +362,8 @@ TEST(TestJsonFileReadWrite, MinimalFormatExample) { #define BATCH_CASES() \ ::testing::Values(&MakeIntRecordBatch, &MakeListRecordBatch, &MakeNonNullRecordBatch, \ &MakeZeroLengthRecordBatch, 
&MakeDeeplyNestedList, &MakeStringTypesRecordBatch, \ - &MakeStruct, &MakeUnion, &MakeDates, &MakeTimestamps, &MakeTimes, &MakeFWBinary); + &MakeStruct, &MakeUnion, &MakeDates, &MakeTimestamps, &MakeTimes, &MakeFWBinary, \ + &MakeDictionary); class TestJsonRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*> { public: @@ -364,6 +372,8 @@ class TestJsonRoundTrip : public ::testing::TestWithParam<MakeRecordBatch*> { }; void CheckRoundtrip(const RecordBatch& batch) { + TestSchemaRoundTrip(*batch.schema()); + std::unique_ptr<JsonWriter> writer; ASSERT_OK(JsonWriter::Open(batch.schema(), &writer)); ASSERT_OK(writer->WriteRecordBatch(batch)); @@ -392,5 +402,6 @@ TEST_P(TestJsonRoundTrip, RoundTrip) { INSTANTIATE_TEST_CASE_P(TestJsonRoundTrip, TestJsonRoundTrip, BATCH_CASES()); +} // namespace json } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/json-internal.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index 2ab3acb..bf2c194 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -32,8 +32,10 @@ #include "arrow/array.h" #include "arrow/builder.h" +#include "arrow/ipc/metadata.h" #include "arrow/memory_pool.h" #include "arrow/status.h" +#include "arrow/table.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/bit-util.h" @@ -43,9 +45,8 @@ namespace arrow { namespace ipc { - -using RjArray = rj::Value::ConstArray; -using RjObject = rj::Value::ConstObject; +namespace json { +namespace internal { static std::string GetBufferTypeName(BufferType type) { switch (type) { @@ -93,20 +94,67 @@ static std::string GetTimeUnitName(TimeUnit::type unit) { return "UNKNOWN"; } -class JsonSchemaWriter { +class SchemaWriter { public: - explicit JsonSchemaWriter(const Schema& schema, RjWriter* writer) + explicit SchemaWriter(const Schema& 
schema, RjWriter* writer) : schema_(schema), writer_(writer) {} Status Write() { + writer_->Key("schema"); writer_->StartObject(); writer_->Key("fields"); writer_->StartArray(); for (const std::shared_ptr<Field>& field : schema_.fields()) { - RETURN_NOT_OK(VisitField(*field.get())); + RETURN_NOT_OK(VisitField(*field)); } writer_->EndArray(); writer_->EndObject(); + + // Write dictionaries, if any + if (dictionary_memo_.size() > 0) { + writer_->Key("dictionaries"); + writer_->StartArray(); + for (const auto& entry : dictionary_memo_.id_to_dictionary()) { + RETURN_NOT_OK(WriteDictionary(entry.first, entry.second)); + } + writer_->EndArray(); + } + return Status::OK(); + } + + Status WriteDictionary(int64_t id, const std::shared_ptr<Array>& dictionary) { + writer_->StartObject(); + writer_->Key("id"); + writer_->Int(static_cast<int32_t>(id)); + writer_->Key("data"); + + // Make a dummy record batch. A bit tedious as we have to make a schema + auto schema = std::shared_ptr<Schema>( + new Schema({arrow::field("dictionary", dictionary->type())})); + RecordBatch batch(schema, dictionary->length(), {dictionary}); + RETURN_NOT_OK(WriteRecordBatch(batch, writer_)); + writer_->EndObject(); + return Status::OK(); + } + + Status WriteDictionaryMetadata(const DictionaryType& type) { + int64_t dictionary_id = dictionary_memo_.GetId(type.dictionary()); + writer_->Key("dictionary"); + + // Emulate DictionaryEncoding from Schema.fbs + writer_->StartObject(); + writer_->Key("id"); + writer_->Int(static_cast<int32_t>(dictionary_id)); + writer_->Key("indexType"); + + writer_->StartObject(); + RETURN_NOT_OK(VisitType(*type.index_type())); + writer_->EndObject(); + + writer_->Key("isOrdered"); + writer_->Bool(type.ordered()); + writer_->EndObject(); + return Status::OK(); } @@ -119,18 +167,33 @@ class JsonSchemaWriter { writer_->Key("nullable"); writer_->Bool(field.nullable()); + const DataType& type = *field.type(); + // Visit the type - RETURN_NOT_OK(VisitTypeInline(*field.type(), 
this)); + writer_->Key("type"); + writer_->StartObject(); + RETURN_NOT_OK(VisitType(type)); + writer_->EndObject(); + + if (type.id() == Type::DICTIONARY) { + const auto& dict_type = static_cast<const DictionaryType&>(type); + RETURN_NOT_OK(WriteDictionaryMetadata(dict_type)); + + const DataType& dictionary_type = *dict_type.dictionary()->type(); + const DataType& index_type = *dict_type.index_type(); + RETURN_NOT_OK(WriteChildren(dictionary_type.children())); + WriteBufferLayout(index_type.GetBufferLayout()); + } else { + RETURN_NOT_OK(WriteChildren(type.children())); + WriteBufferLayout(type.GetBufferLayout()); + } + writer_->EndObject(); return Status::OK(); } - void SetNoChildren() { - writer_->Key("children"); - writer_->StartArray(); - writer_->EndArray(); - } + Status VisitType(const DataType& type); template <typename T> typename std::enable_if<std::is_base_of<NoExtraMeta, T>::value || @@ -227,27 +290,20 @@ class JsonSchemaWriter { template <typename T> void WriteName(const std::string& typeclass, const T& type) { - writer_->Key("type"); - writer_->StartObject(); writer_->Key("name"); writer_->String(typeclass); WriteTypeMetadata(type); - writer_->EndObject(); } template <typename T> Status WritePrimitive(const std::string& typeclass, const T& type) { WriteName(typeclass, type); - SetNoChildren(); - WriteBufferLayout(type.GetBufferLayout()); return Status::OK(); } template <typename T> Status WriteVarBytes(const std::string& typeclass, const T& type) { WriteName(typeclass, type); - SetNoChildren(); - WriteBufferLayout(type.GetBufferLayout()); return Status::OK(); } @@ -275,16 +331,14 @@ class JsonSchemaWriter { writer_->Key("children"); writer_->StartArray(); for (const std::shared_ptr<Field>& field : children) { - RETURN_NOT_OK(VisitField(*field.get())); + RETURN_NOT_OK(VisitField(*field)); } writer_->EndArray(); return Status::OK(); } Status Visit(const NullType& type) { return WritePrimitive("null", type); } - Status Visit(const BooleanType& type) { 
return WritePrimitive("bool", type); } - Status Visit(const Integer& type) { return WritePrimitive("int", type); } Status Visit(const FloatingPoint& type) { @@ -292,60 +346,57 @@ class JsonSchemaWriter { } Status Visit(const DateType& type) { return WritePrimitive("date", type); } - Status Visit(const TimeType& type) { return WritePrimitive("time", type); } - Status Visit(const StringType& type) { return WriteVarBytes("utf8", type); } - Status Visit(const BinaryType& type) { return WriteVarBytes("binary", type); } - Status Visit(const FixedSizeBinaryType& type) { return WritePrimitive("fixedsizebinary", type); } Status Visit(const TimestampType& type) { return WritePrimitive("timestamp", type); } - Status Visit(const IntervalType& type) { return WritePrimitive("interval", type); } Status Visit(const ListType& type) { WriteName("list", type); - RETURN_NOT_OK(WriteChildren(type.children())); - WriteBufferLayout(type.GetBufferLayout()); return Status::OK(); } Status Visit(const StructType& type) { WriteName("struct", type); - WriteChildren(type.children()); - WriteBufferLayout(type.GetBufferLayout()); return Status::OK(); } Status Visit(const UnionType& type) { WriteName("union", type); - WriteChildren(type.children()); - WriteBufferLayout(type.GetBufferLayout()); return Status::OK(); } Status Visit(const DecimalType& type) { return Status::NotImplemented("decimal"); } Status Visit(const DictionaryType& type) { - return Status::NotImplemented("dictionary"); + return VisitType(*type.dictionary()->type()); } private: + DictionaryMemo dictionary_memo_; + const Schema& schema_; RjWriter* writer_; }; -class JsonArrayWriter { +Status SchemaWriter::VisitType(const DataType& type) { + return VisitTypeInline(type, this); +} + +class ArrayWriter { public: - JsonArrayWriter(const std::string& name, const Array& array, RjWriter* writer) + ArrayWriter(const std::string& name, const Array& array, RjWriter* writer) : name_(name), array_(array), writer_(writer) {} Status Write() { 
return VisitArray(name_, array_); } + Status VisitArrayValues(const Array& arr) { return VisitArrayInline(arr, this); } + Status VisitArray(const std::string& name, const Array& arr) { writer_->StartObject(); writer_->Key("name"); @@ -354,7 +405,7 @@ class JsonArrayWriter { writer_->Key("count"); writer_->Int(static_cast<int32_t>(arr.length())); - RETURN_NOT_OK(VisitArrayInline(arr, this)); + RETURN_NOT_OK(VisitArrayValues(arr)); writer_->EndObject(); return Status::OK(); @@ -461,7 +512,7 @@ class JsonArrayWriter { writer_->Key("children"); writer_->StartArray(); for (size_t i = 0; i < fields.size(); ++i) { - RETURN_NOT_OK(VisitArray(fields[i]->name(), *arrays[i].get())); + RETURN_NOT_OK(VisitArray(fields[i]->name(), *arrays[i])); } writer_->EndArray(); return Status::OK(); @@ -493,30 +544,32 @@ class JsonArrayWriter { Status Visit(const DecimalArray& array) { return Status::NotImplemented("decimal"); } - Status Visit(const DictionaryArray& array) { return Status::NotImplemented("decimal"); } + Status Visit(const DictionaryArray& array) { + return VisitArrayValues(*array.indices()); + } Status Visit(const ListArray& array) { WriteValidityField(array); WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length() + 1); - auto type = static_cast<const ListType*>(array.type().get()); - return WriteChildren(type->children(), {array.values()}); + const auto& type = static_cast<const ListType&>(*array.type()); + return WriteChildren(type.children(), {array.values()}); } Status Visit(const StructArray& array) { WriteValidityField(array); - auto type = static_cast<const StructType*>(array.type().get()); - return WriteChildren(type->children(), array.fields()); + const auto& type = static_cast<const StructType&>(*array.type()); + return WriteChildren(type.children(), array.fields()); } Status Visit(const UnionArray& array) { WriteValidityField(array); - auto type = static_cast<const UnionType*>(array.type().get()); + const auto& type = static_cast<const 
UnionType&>(*array.type()); WriteIntegerField("TYPE_ID", array.raw_type_ids(), array.length()); - if (type->mode() == UnionMode::DENSE) { + if (type.mode() == UnionMode::DENSE) { WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length()); } - return WriteChildren(type->children(), array.children()); + return WriteChildren(type.children(), array.children()); } private: @@ -525,16 +578,38 @@ class JsonArrayWriter { RjWriter* writer_; }; +static Status GetObjectInt(const RjObject& obj, const std::string& key, int* out) { + const auto& it = obj.FindMember(key); + RETURN_NOT_INT(key, it, obj); + *out = it->value.GetInt(); + return Status::OK(); +} + +static Status GetObjectBool(const RjObject& obj, const std::string& key, bool* out) { + const auto& it = obj.FindMember(key); + RETURN_NOT_BOOL(key, it, obj); + *out = it->value.GetBool(); + return Status::OK(); +} + +static Status GetObjectString( + const RjObject& obj, const std::string& key, std::string* out) { + const auto& it = obj.FindMember(key); + RETURN_NOT_STRING(key, it, obj); + *out = it->value.GetString(); + return Status::OK(); +} + static Status GetInteger( const rj::Value::ConstObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_bit_width = json_type.FindMember("bitWidth"); - RETURN_NOT_INT("bitWidth", json_bit_width, json_type); + const auto& it_bit_width = json_type.FindMember("bitWidth"); + RETURN_NOT_INT("bitWidth", it_bit_width, json_type); - const auto& json_is_signed = json_type.FindMember("isSigned"); - RETURN_NOT_BOOL("isSigned", json_is_signed, json_type); + const auto& it_is_signed = json_type.FindMember("isSigned"); + RETURN_NOT_BOOL("isSigned", it_is_signed, json_type); - bool is_signed = json_is_signed->value.GetBool(); - int bit_width = json_bit_width->value.GetInt(); + bool is_signed = it_is_signed->value.GetBool(); + int bit_width = it_bit_width->value.GetInt(); switch (bit_width) { case 8: @@ -559,10 +634,10 @@ static Status GetInteger( static Status 
GetFloatingPoint( const RjObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_precision = json_type.FindMember("precision"); - RETURN_NOT_STRING("precision", json_precision, json_type); + const auto& it_precision = json_type.FindMember("precision"); + RETURN_NOT_STRING("precision", it_precision, json_type); - std::string precision = json_precision->value.GetString(); + std::string precision = it_precision->value.GetString(); if (precision == "DOUBLE") { *type = float64(); @@ -580,19 +655,19 @@ static Status GetFloatingPoint( static Status GetFixedSizeBinary( const RjObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_byte_width = json_type.FindMember("byteWidth"); - RETURN_NOT_INT("byteWidth", json_byte_width, json_type); + const auto& it_byte_width = json_type.FindMember("byteWidth"); + RETURN_NOT_INT("byteWidth", it_byte_width, json_type); - int32_t byte_width = json_byte_width->value.GetInt(); + int32_t byte_width = it_byte_width->value.GetInt(); *type = fixed_size_binary(byte_width); return Status::OK(); } static Status GetDate(const RjObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_unit = json_type.FindMember("unit"); - RETURN_NOT_STRING("unit", json_unit, json_type); + const auto& it_unit = json_type.FindMember("unit"); + RETURN_NOT_STRING("unit", it_unit, json_type); - std::string unit_str = json_unit->value.GetString(); + std::string unit_str = it_unit->value.GetString(); if (unit_str == "DAY") { *type = date32(); @@ -607,13 +682,13 @@ static Status GetDate(const RjObject& json_type, std::shared_ptr<DataType>* type } static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_unit = json_type.FindMember("unit"); - RETURN_NOT_STRING("unit", json_unit, json_type); + const auto& it_unit = json_type.FindMember("unit"); + RETURN_NOT_STRING("unit", it_unit, json_type); - const auto& json_bit_width = json_type.FindMember("bitWidth"); - 
RETURN_NOT_INT("bitWidth", json_bit_width, json_type); + const auto& it_bit_width = json_type.FindMember("bitWidth"); + RETURN_NOT_INT("bitWidth", it_bit_width, json_type); - std::string unit_str = json_unit->value.GetString(); + std::string unit_str = it_unit->value.GetString(); if (unit_str == "SECOND") { *type = time32(TimeUnit::SECOND); @@ -631,7 +706,7 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type const auto& fw_type = static_cast<const FixedWidthType&>(**type); - int bit_width = json_bit_width->value.GetInt(); + int bit_width = it_bit_width->value.GetInt(); if (bit_width != fw_type.bit_width()) { return Status::Invalid("Indicated bit width does not match unit"); } @@ -640,10 +715,10 @@ static Status GetTime(const RjObject& json_type, std::shared_ptr<DataType>* type } static Status GetTimestamp(const RjObject& json_type, std::shared_ptr<DataType>* type) { - const auto& json_unit = json_type.FindMember("unit"); - RETURN_NOT_STRING("unit", json_unit, json_type); + const auto& it_unit = json_type.FindMember("unit"); + RETURN_NOT_STRING("unit", it_unit, json_type); - std::string unit_str = json_unit->value.GetString(); + std::string unit_str = it_unit->value.GetString(); TimeUnit::type unit; if (unit_str == "SECOND") { @@ -660,11 +735,11 @@ static Status GetTimestamp(const RjObject& json_type, std::shared_ptr<DataType>* return Status::Invalid(ss.str()); } - const auto& json_tz = json_type.FindMember("timezone"); - if (json_tz == json_type.MemberEnd()) { + const auto& it_tz = json_type.FindMember("timezone"); + if (it_tz == json_type.MemberEnd()) { *type = timestamp(unit); } else { - *type = timestamp(unit, json_tz->value.GetString()); + *type = timestamp(unit, it_tz->value.GetString()); } return Status::OK(); @@ -673,10 +748,10 @@ static Status GetTimestamp(const RjObject& json_type, std::shared_ptr<DataType>* static Status GetUnion(const RjObject& json_type, const std::vector<std::shared_ptr<Field>>& children, 
std::shared_ptr<DataType>* type) { - const auto& json_mode = json_type.FindMember("mode"); - RETURN_NOT_STRING("mode", json_mode, json_type); + const auto& it_mode = json_type.FindMember("mode"); + RETURN_NOT_STRING("mode", it_mode, json_type); - std::string mode_str = json_mode->value.GetString(); + std::string mode_str = it_mode->value.GetString(); UnionMode mode; if (mode_str == "SPARSE") { @@ -689,11 +764,11 @@ static Status GetUnion(const RjObject& json_type, return Status::Invalid(ss.str()); } - const auto& json_type_codes = json_type.FindMember("typeIds"); - RETURN_NOT_ARRAY("typeIds", json_type_codes, json_type); + const auto& it_type_codes = json_type.FindMember("typeIds"); + RETURN_NOT_ARRAY("typeIds", it_type_codes, json_type); std::vector<uint8_t> type_codes; - const auto& id_array = json_type_codes->value.GetArray(); + const auto& id_array = it_type_codes->value.GetArray(); for (const rj::Value& val : id_array) { DCHECK(val.IsUint()); type_codes.push_back(static_cast<uint8_t>(val.GetUint())); @@ -707,10 +782,10 @@ static Status GetUnion(const RjObject& json_type, static Status GetType(const RjObject& json_type, const std::vector<std::shared_ptr<Field>>& children, std::shared_ptr<DataType>* type) { - const auto& json_type_name = json_type.FindMember("name"); - RETURN_NOT_STRING("name", json_type_name, json_type); + const auto& it_type_name = json_type.FindMember("name"); + RETURN_NOT_STRING("name", it_type_name, json_type); - std::string type_name = json_type_name->value.GetString(); + std::string type_name = it_type_name->value.GetString(); if (type_name == "int") { return GetInteger(json_type, type); @@ -733,6 +808,9 @@ static Status GetType(const RjObject& json_type, } else if (type_name == "timestamp") { return GetTimestamp(json_type, type); } else if (type_name == "list") { + if (children.size() != 1) { + return Status::Invalid("List must have exactly one child"); + } *type = list(children[0]); } else if (type_name == "struct") { *type = 
struct_(children); @@ -742,43 +820,83 @@ static Status GetType(const RjObject& json_type, return Status::OK(); } -static Status GetField(const rj::Value& obj, std::shared_ptr<Field>* field); +static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_memo, + std::shared_ptr<Field>* field); -static Status GetFieldsFromArray( - const rj::Value& obj, std::vector<std::shared_ptr<Field>>* fields) { +static Status GetFieldsFromArray(const rj::Value& obj, + const DictionaryMemo* dictionary_memo, std::vector<std::shared_ptr<Field>>* fields) { const auto& values = obj.GetArray(); fields->resize(values.Size()); for (rj::SizeType i = 0; i < fields->size(); ++i) { - RETURN_NOT_OK(GetField(values[i], &(*fields)[i])); + RETURN_NOT_OK(GetField(values[i], dictionary_memo, &(*fields)[i])); } return Status::OK(); } -static Status GetField(const rj::Value& obj, std::shared_ptr<Field>* field) { +static Status ParseDictionary(const RjObject& obj, int64_t* id, bool* is_ordered, + std::shared_ptr<DataType>* index_type) { + int32_t int32_id; + RETURN_NOT_OK(GetObjectInt(obj, "id", &int32_id)); + *id = int32_id; + + RETURN_NOT_OK(GetObjectBool(obj, "isOrdered", is_ordered)); + + const auto& it_index_type = obj.FindMember("indexType"); + RETURN_NOT_OBJECT("indexType", it_index_type, obj); + + const auto& json_index_type = it_index_type->value.GetObject(); + + std::string type_name; + RETURN_NOT_OK(GetObjectString(json_index_type, "name", &type_name)); + if (type_name != "int") { + return Status::Invalid("Dictionary indices can only be integers"); + } + return GetInteger(json_index_type, index_type); +} + +static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_memo, + std::shared_ptr<Field>* field) { if (!obj.IsObject()) { return Status::Invalid("Field was not a JSON object"); } const auto& json_field = obj.GetObject(); - const auto& json_name = json_field.FindMember("name"); - RETURN_NOT_STRING("name", json_name, json_field); + std::string name; + 
bool nullable; + RETURN_NOT_OK(GetObjectString(json_field, "name", &name)); + RETURN_NOT_OK(GetObjectBool(json_field, "nullable", &nullable)); - const auto& json_nullable = json_field.FindMember("nullable"); - RETURN_NOT_BOOL("nullable", json_nullable, json_field); + std::shared_ptr<DataType> type; - const auto& json_type = json_field.FindMember("type"); - RETURN_NOT_OBJECT("type", json_type, json_field); + const auto& it_dictionary = json_field.FindMember("dictionary"); + if (dictionary_memo != nullptr && it_dictionary != json_field.MemberEnd()) { + // Field is dictionary encoded. We must have already read the dictionary + RETURN_NOT_OBJECT("dictionary", it_dictionary, json_field); + int64_t dictionary_id; + bool is_ordered; + std::shared_ptr<DataType> index_type; + RETURN_NOT_OK(ParseDictionary( + it_dictionary->value.GetObject(), &dictionary_id, &is_ordered, &index_type)); - const auto& json_children = json_field.FindMember("children"); - RETURN_NOT_ARRAY("children", json_children, json_field); + std::shared_ptr<Array> dictionary; + RETURN_NOT_OK(dictionary_memo->GetDictionary(dictionary_id, &dictionary)); - std::vector<std::shared_ptr<Field>> children; - RETURN_NOT_OK(GetFieldsFromArray(json_children->value, &children)); + type = std::make_shared<DictionaryType>(index_type, dictionary, is_ordered); + } else { + // If the dictionary_memo was not passed, or if the field is not dictionary + // encoded, we are interested in the complete type including all children - std::shared_ptr<DataType> type; - RETURN_NOT_OK(GetType(json_type->value.GetObject(), children, &type)); + const auto& it_type = json_field.FindMember("type"); + RETURN_NOT_OBJECT("type", it_type, json_field); + + const auto& it_children = json_field.FindMember("children"); + RETURN_NOT_ARRAY("children", it_children, json_field); - *field = std::make_shared<Field>( - json_name->value.GetString(), type, json_nullable->value.GetBool()); + std::vector<std::shared_ptr<Field>> children; + 
RETURN_NOT_OK(GetFieldsFromArray(it_children->value, dictionary_memo, &children)); + RETURN_NOT_OK(GetType(it_type->value.GetObject(), children, &type)); + } + + *field = std::make_shared<Field>(name, type, nullable); return Status::OK(); } @@ -810,9 +928,13 @@ UnboxValue(const rj::Value& val) { return val.GetBool(); } -class JsonArrayReader { +class ArrayReader { public: - explicit JsonArrayReader(MemoryPool* pool) : pool_(pool) {} + explicit ArrayReader(const rj::Value& json_array, const std::shared_ptr<DataType>& type, + MemoryPool* pool) + : json_array_(json_array), type_(type), pool_(pool) {} + + Status ParseTypeValues(const DataType& type); Status GetValidityBuffer(const std::vector<bool>& is_valid, int32_t* null_count, std::shared_ptr<Buffer>* validity_buffer) { @@ -841,18 +963,17 @@ class JsonArrayReader { std::is_base_of<TimestampType, T>::value || std::is_base_of<TimeType, T>::value || std::is_base_of<BooleanType, T>::value, Status>::type - ReadArray(const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { - typename TypeTraits<T>::BuilderType builder(pool_, type); + Visit(const T& type) { + typename TypeTraits<T>::BuilderType builder(pool_, type_); - const auto& json_data = json_array.FindMember("DATA"); - RETURN_NOT_ARRAY("DATA", json_data, json_array); + const auto& json_data = obj_->FindMember("DATA"); + RETURN_NOT_ARRAY("DATA", json_data, *obj_); const auto& json_data_arr = json_data->value.GetArray(); - DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length); - for (int i = 0; i < length; ++i) { - if (!is_valid[i]) { + DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length_); + for (int i = 0; i < length_; ++i) { + if (!is_valid_[i]) { builder.AppendNull(); continue; } @@ -861,25 +982,24 @@ class JsonArrayReader { builder.Append(UnboxValue<T>(val)); } - return builder.Finish(array); + return builder.Finish(&result_); } template <typename T> - 
typename std::enable_if<std::is_base_of<BinaryType, T>::value, Status>::type ReadArray( - const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { + typename std::enable_if<std::is_base_of<BinaryType, T>::value, Status>::type Visit( + const T& type) { typename TypeTraits<T>::BuilderType builder(pool_); - const auto& json_data = json_array.FindMember("DATA"); - RETURN_NOT_ARRAY("DATA", json_data, json_array); + const auto& json_data = obj_->FindMember("DATA"); + RETURN_NOT_ARRAY("DATA", json_data, *obj_); const auto& json_data_arr = json_data->value.GetArray(); - DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length); + DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length_); auto byte_buffer = std::make_shared<PoolBuffer>(pool_); - for (int i = 0; i < length; ++i) { - if (!is_valid[i]) { + for (int i = 0; i < length_; ++i) { + if (!is_valid_[i]) { builder.AppendNull(); continue; } @@ -905,31 +1025,29 @@ class JsonArrayReader { } } - return builder.Finish(array); + return builder.Finish(&result_); } template <typename T> typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T>::value, Status>::type - ReadArray(const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { - FixedSizeBinaryBuilder builder(pool_, type); + Visit(const T& type) { + FixedSizeBinaryBuilder builder(pool_, type_); - const auto& json_data = json_array.FindMember("DATA"); - RETURN_NOT_ARRAY("DATA", json_data, json_array); + const auto& json_data = obj_->FindMember("DATA"); + RETURN_NOT_ARRAY("DATA", json_data, *obj_); const auto& json_data_arr = json_data->value.GetArray(); - DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length); - - int32_t byte_width = static_cast<const FixedSizeBinaryType&>(*type).byte_width(); + DCHECK_EQ(static_cast<int32_t>(json_data_arr.Size()), length_); + 
int32_t byte_width = type.byte_width(); // Allocate space for parsed values std::shared_ptr<MutableBuffer> byte_buffer; RETURN_NOT_OK(AllocateBuffer(pool_, byte_width, &byte_buffer)); uint8_t* byte_buffer_data = byte_buffer->mutable_data(); - for (int i = 0; i < length; ++i) { - if (!is_valid[i]) { + for (int i = 0; i < length_; ++i) { + if (!is_valid_[i]) { builder.AppendNull(); continue; } @@ -946,7 +1064,7 @@ class JsonArrayReader { } RETURN_NOT_OK(builder.Append(byte_buffer_data)); } - return builder.Finish(array); + return builder.Finish(&result_); } template <typename T> @@ -966,99 +1084,97 @@ class JsonArrayReader { return Status::OK(); } - template <typename T> - typename std::enable_if<std::is_base_of<ListType, T>::value, Status>::type ReadArray( - const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { + Status Visit(const ListType& type) { int32_t null_count = 0; std::shared_ptr<Buffer> validity_buffer; - RETURN_NOT_OK(GetValidityBuffer(is_valid, &null_count, &validity_buffer)); + RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); - const auto& json_offsets = json_array.FindMember("OFFSET"); - RETURN_NOT_ARRAY("OFFSET", json_offsets, json_array); + const auto& json_offsets = obj_->FindMember("OFFSET"); + RETURN_NOT_ARRAY("OFFSET", json_offsets, *obj_); std::shared_ptr<Buffer> offsets_buffer; RETURN_NOT_OK(GetIntArray<int32_t>( - json_offsets->value.GetArray(), length + 1, &offsets_buffer)); + json_offsets->value.GetArray(), length_ + 1, &offsets_buffer)); std::vector<std::shared_ptr<Array>> children; - RETURN_NOT_OK(GetChildren(json_array, type, &children)); + RETURN_NOT_OK(GetChildren(*obj_, type, &children)); DCHECK_EQ(children.size(), 1); - *array = std::make_shared<ListArray>( - type, length, offsets_buffer, children[0], validity_buffer, null_count); + result_ = std::make_shared<ListArray>( + type_, length_, offsets_buffer, 
children[0], validity_buffer, null_count); return Status::OK(); } - template <typename T> - typename std::enable_if<std::is_base_of<StructType, T>::value, Status>::type ReadArray( - const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { + Status Visit(const StructType& type) { int32_t null_count = 0; std::shared_ptr<Buffer> validity_buffer; - RETURN_NOT_OK(GetValidityBuffer(is_valid, &null_count, &validity_buffer)); + RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); std::vector<std::shared_ptr<Array>> fields; - RETURN_NOT_OK(GetChildren(json_array, type, &fields)); + RETURN_NOT_OK(GetChildren(*obj_, type, &fields)); - *array = - std::make_shared<StructArray>(type, length, fields, validity_buffer, null_count); + result_ = std::make_shared<StructArray>( + type_, length_, fields, validity_buffer, null_count); return Status::OK(); } - template <typename T> - typename std::enable_if<std::is_base_of<UnionType, T>::value, Status>::type ReadArray( - const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { + Status Visit(const UnionType& type) { int32_t null_count = 0; - const auto& union_type = static_cast<const UnionType&>(*type.get()); - std::shared_ptr<Buffer> validity_buffer; std::shared_ptr<Buffer> type_id_buffer; std::shared_ptr<Buffer> offsets_buffer; - RETURN_NOT_OK(GetValidityBuffer(is_valid, &null_count, &validity_buffer)); + RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); - const auto& json_type_ids = json_array.FindMember("TYPE_ID"); - RETURN_NOT_ARRAY("TYPE_ID", json_type_ids, json_array); + const auto& json_type_ids = obj_->FindMember("TYPE_ID"); + RETURN_NOT_ARRAY("TYPE_ID", json_type_ids, *obj_); RETURN_NOT_OK( - GetIntArray<uint8_t>(json_type_ids->value.GetArray(), length, &type_id_buffer)); + 
GetIntArray<uint8_t>(json_type_ids->value.GetArray(), length_, &type_id_buffer)); - if (union_type.mode() == UnionMode::DENSE) { - const auto& json_offsets = json_array.FindMember("OFFSET"); - RETURN_NOT_ARRAY("OFFSET", json_offsets, json_array); + if (type.mode() == UnionMode::DENSE) { + const auto& json_offsets = obj_->FindMember("OFFSET"); + RETURN_NOT_ARRAY("OFFSET", json_offsets, *obj_); RETURN_NOT_OK( - GetIntArray<int32_t>(json_offsets->value.GetArray(), length, &offsets_buffer)); + GetIntArray<int32_t>(json_offsets->value.GetArray(), length_, &offsets_buffer)); } std::vector<std::shared_ptr<Array>> children; - RETURN_NOT_OK(GetChildren(json_array, type, &children)); + RETURN_NOT_OK(GetChildren(*obj_, type, &children)); - *array = std::make_shared<UnionArray>(type, length, children, type_id_buffer, + result_ = std::make_shared<UnionArray>(type_, length_, children, type_id_buffer, offsets_buffer, validity_buffer, null_count); return Status::OK(); } - template <typename T> - typename std::enable_if<std::is_base_of<NullType, T>::value, Status>::type ReadArray( - const RjObject& json_array, int32_t length, const std::vector<bool>& is_valid, - const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { - *array = std::make_shared<NullArray>(length); + Status Visit(const NullType& type) { + result_ = std::make_shared<NullArray>(length_); + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + // This stores the indices in result_ + // + // XXX(wesm): slight hack + auto dict_type = type_; + type_ = type.index_type(); + RETURN_NOT_OK(ParseTypeValues(*type_)); + type_ = dict_type; + result_ = std::make_shared<DictionaryArray>(type_, result_); return Status::OK(); } - Status GetChildren(const RjObject& json_array, const std::shared_ptr<DataType>& type, + Status GetChildren(const RjObject& obj, const DataType& type, std::vector<std::shared_ptr<Array>>* array) { - const auto& json_children = json_array.FindMember("children"); - 
RETURN_NOT_ARRAY("children", json_children, json_array); + const auto& json_children = obj.FindMember("children"); + RETURN_NOT_ARRAY("children", json_children, obj); const auto& json_children_arr = json_children->value.GetArray(); - if (type->num_children() != static_cast<int>(json_children_arr.Size())) { + if (type.num_children() != static_cast<int>(json_children_arr.Size())) { std::stringstream ss; - ss << "Expected " << type->num_children() << " children, but got " + ss << "Expected " << type.num_children() << " children, but got " << json_children_arr.Size(); return Status::Invalid(ss.str()); } @@ -1067,128 +1183,237 @@ class JsonArrayReader { const rj::Value& json_child = json_children_arr[i]; DCHECK(json_child.IsObject()); - std::shared_ptr<Field> child_field = type->child(i); + std::shared_ptr<Field> child_field = type.child(i); auto it = json_child.FindMember("name"); RETURN_NOT_STRING("name", it, json_child); DCHECK_EQ(it->value.GetString(), child_field->name()); std::shared_ptr<Array> child; - RETURN_NOT_OK(GetArray(json_children_arr[i], child_field->type(), &child)); + RETURN_NOT_OK(ReadArray(pool_, json_children_arr[i], child_field->type(), &child)); array->emplace_back(child); } return Status::OK(); } - Status GetArray(const rj::Value& obj, const std::shared_ptr<DataType>& type, - std::shared_ptr<Array>* array) { - if (!obj.IsObject()) { + Status GetArray(std::shared_ptr<Array>* out) { + if (!json_array_.IsObject()) { return Status::Invalid("Array element was not a JSON object"); } - const auto& json_array = obj.GetObject(); - const auto& json_length = json_array.FindMember("count"); - RETURN_NOT_INT("count", json_length, json_array); - int32_t length = json_length->value.GetInt(); + auto obj = json_array_.GetObject(); + obj_ = &obj; - const auto& json_valid_iter = json_array.FindMember("VALIDITY"); - RETURN_NOT_ARRAY("VALIDITY", json_valid_iter, json_array); + RETURN_NOT_OK(GetObjectInt(obj, "count", &length_)); - const auto& json_validity = 
json_valid_iter->value.GetArray(); - - DCHECK_EQ(static_cast<int>(json_validity.Size()), length); + const auto& json_valid_iter = obj.FindMember("VALIDITY"); + RETURN_NOT_ARRAY("VALIDITY", json_valid_iter, obj); - std::vector<bool> is_valid; + const auto& json_validity = json_valid_iter->value.GetArray(); + DCHECK_EQ(static_cast<int>(json_validity.Size()), length_); for (const rj::Value& val : json_validity) { DCHECK(val.IsInt()); - is_valid.push_back(val.GetInt() != 0); - } - -#define TYPE_CASE(TYPE) \ - case TYPE::type_id: \ - return ReadArray<TYPE>(json_array, length, is_valid, type, array); - - switch (type->id()) { - TYPE_CASE(NullType); - TYPE_CASE(BooleanType); - TYPE_CASE(UInt8Type); - TYPE_CASE(Int8Type); - TYPE_CASE(UInt16Type); - TYPE_CASE(Int16Type); - TYPE_CASE(UInt32Type); - TYPE_CASE(Int32Type); - TYPE_CASE(UInt64Type); - TYPE_CASE(Int64Type); - TYPE_CASE(HalfFloatType); - TYPE_CASE(FloatType); - TYPE_CASE(DoubleType); - TYPE_CASE(StringType); - TYPE_CASE(BinaryType); - TYPE_CASE(FixedSizeBinaryType); - TYPE_CASE(Date32Type); - TYPE_CASE(Date64Type); - TYPE_CASE(TimestampType); - TYPE_CASE(Time32Type); - TYPE_CASE(Time64Type); - TYPE_CASE(ListType); - TYPE_CASE(StructType); - TYPE_CASE(UnionType); - default: - std::stringstream ss; - ss << type->ToString(); - return Status::NotImplemented(ss.str()); + is_valid_.push_back(val.GetInt() != 0); } -#undef TYPE_CASE - + RETURN_NOT_OK(ParseTypeValues(*type_)); + *out = result_; return Status::OK(); } private: + const rj::Value& json_array_; + const RjObject* obj_; + std::shared_ptr<DataType> type_; MemoryPool* pool_; + + // Parsed common attributes + std::vector<bool> is_valid_; + int32_t length_; + std::shared_ptr<Array> result_; }; -Status WriteJsonSchema(const Schema& schema, RjWriter* json_writer) { - JsonSchemaWriter converter(schema, json_writer); +Status ArrayReader::ParseTypeValues(const DataType& type) { + return VisitTypeInline(type, this); +} + +Status WriteSchema(const Schema& schema, RjWriter* 
json_writer) { + SchemaWriter converter(schema, json_writer); return converter.Write(); } -Status ReadJsonSchema(const rj::Value& json_schema, std::shared_ptr<Schema>* schema) { - const auto& obj_schema = json_schema.GetObject(); +static Status LookForDictionaries(const rj::Value& obj, DictionaryTypeMap* id_to_field) { + const auto& json_field = obj.GetObject(); + + const auto& it_dictionary = json_field.FindMember("dictionary"); + if (it_dictionary == json_field.MemberEnd()) { + // Not dictionary-encoded + return Status::OK(); + } + + // Dictionary encoded. Construct the field and set in the type map + std::shared_ptr<Field> dictionary_field; + RETURN_NOT_OK(GetField(obj, nullptr, &dictionary_field)); + + int id; + RETURN_NOT_OK(GetObjectInt(it_dictionary->value.GetObject(), "id", &id)); + (*id_to_field)[id] = dictionary_field; + return Status::OK(); +} + +static Status GetDictionaryTypes(const RjArray& fields, DictionaryTypeMap* id_to_field) { + for (rj::SizeType i = 0; i < fields.Size(); ++i) { + RETURN_NOT_OK(LookForDictionaries(fields[i], id_to_field)); + } + return Status::OK(); +} + +static Status ReadDictionary(const RjObject& obj, const DictionaryTypeMap& id_to_field, + MemoryPool* pool, int64_t* dictionary_id, std::shared_ptr<Array>* out) { + int id; + RETURN_NOT_OK(GetObjectInt(obj, "id", &id)); - const auto& json_fields = obj_schema.FindMember("fields"); - RETURN_NOT_ARRAY("fields", json_fields, obj_schema); + const auto& it_data = obj.FindMember("data"); + RETURN_NOT_OBJECT("data", it_data, obj); + + auto it = id_to_field.find(id); + if (it == id_to_field.end()) { + std::stringstream ss; + ss << "No dictionary with id " << id; + return Status::Invalid(ss.str()); + } + std::vector<std::shared_ptr<Field>> fields = {it->second}; + + // We need a schema for the record batch + auto dummy_schema = std::make_shared<Schema>(fields); + + // The dictionary is embedded in a record batch with a single column + std::shared_ptr<RecordBatch> batch; + 
RETURN_NOT_OK(ReadRecordBatch(it_data->value, dummy_schema, pool, &batch)); + + if (batch->num_columns() != 1) { + return Status::Invalid("Dictionary record batch must only contain one field"); + } + + *dictionary_id = id; + *out = batch->column(0); + return Status::OK(); +} + +static Status ReadDictionaries(const rj::Value& doc, const DictionaryTypeMap& id_to_field, + MemoryPool* pool, DictionaryMemo* dictionary_memo) { + auto it = doc.FindMember("dictionaries"); + if (it == doc.MemberEnd()) { + // No dictionaries + return Status::OK(); + } + + RETURN_NOT_ARRAY("dictionaries", it, doc); + const auto& dictionary_array = it->value.GetArray(); + + for (const rj::Value& val : dictionary_array) { + DCHECK(val.IsObject()); + int64_t dictionary_id; + std::shared_ptr<Array> dictionary; + RETURN_NOT_OK( + ReadDictionary(val.GetObject(), id_to_field, pool, &dictionary_id, &dictionary)); + + RETURN_NOT_OK(dictionary_memo->AddDictionary(dictionary_id, dictionary)); + } + return Status::OK(); +} + +Status ReadSchema( + const rj::Value& json_schema, MemoryPool* pool, std::shared_ptr<Schema>* schema) { + auto it = json_schema.FindMember("schema"); + RETURN_NOT_OBJECT("schema", it, json_schema); + const auto& obj_schema = it->value.GetObject(); + + const auto& it_fields = obj_schema.FindMember("fields"); + RETURN_NOT_ARRAY("fields", it_fields, obj_schema); + + // Determine the dictionary types + DictionaryTypeMap dictionary_types; + RETURN_NOT_OK(GetDictionaryTypes(it_fields->value.GetArray(), &dictionary_types)); + + // Read the dictionaries (if any) and cache in the memo + DictionaryMemo dictionary_memo; + RETURN_NOT_OK(ReadDictionaries(json_schema, dictionary_types, pool, &dictionary_memo)); std::vector<std::shared_ptr<Field>> fields; - RETURN_NOT_OK(GetFieldsFromArray(json_fields->value, &fields)); + RETURN_NOT_OK(GetFieldsFromArray(it_fields->value, &dictionary_memo, &fields)); *schema = std::make_shared<Schema>(fields); return Status::OK(); } -Status WriteJsonArray( - const 
std::string& name, const Array& array, RjWriter* json_writer) { - JsonArrayWriter converter(name, array, json_writer); +Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr<Schema>& schema, + MemoryPool* pool, std::shared_ptr<RecordBatch>* batch) { + DCHECK(json_obj.IsObject()); + const auto& batch_obj = json_obj.GetObject(); + + auto it = batch_obj.FindMember("count"); + RETURN_NOT_INT("count", it, batch_obj); + int32_t num_rows = static_cast<int32_t>(it->value.GetInt()); + + it = batch_obj.FindMember("columns"); + RETURN_NOT_ARRAY("columns", it, batch_obj); + const auto& json_columns = it->value.GetArray(); + + std::vector<std::shared_ptr<Array>> columns(json_columns.Size()); + for (int i = 0; i < static_cast<int>(columns.size()); ++i) { + const std::shared_ptr<DataType>& type = schema->field(i)->type(); + RETURN_NOT_OK(ReadArray(pool, json_columns[i], type, &columns[i])); + } + + *batch = std::make_shared<RecordBatch>(schema, num_rows, columns); + return Status::OK(); +} + +Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer) { + writer->StartObject(); + writer->Key("count"); + writer->Int(static_cast<int32_t>(batch.num_rows())); + + writer->Key("columns"); + writer->StartArray(); + + for (int i = 0; i < batch.num_columns(); ++i) { + const std::shared_ptr<Array>& column = batch.column(i); + + DCHECK_EQ(batch.num_rows(), column->length()) + << "Array length did not match record batch length"; + + RETURN_NOT_OK(WriteArray(batch.column_name(i), *column, writer)); + } + + writer->EndArray(); + writer->EndObject(); + return Status::OK(); +} + +Status WriteArray(const std::string& name, const Array& array, RjWriter* json_writer) { + ArrayWriter converter(name, array, json_writer); return converter.Write(); } -Status ReadJsonArray(MemoryPool* pool, const rj::Value& json_array, +Status ReadArray(MemoryPool* pool, const rj::Value& json_array, const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array) { - JsonArrayReader 
converter(pool); - return converter.GetArray(json_array, type, array); + ArrayReader converter(json_array, type, pool); + return converter.GetArray(array); } -Status ReadJsonArray(MemoryPool* pool, const rj::Value& json_array, const Schema& schema, +Status ReadArray(MemoryPool* pool, const rj::Value& json_array, const Schema& schema, std::shared_ptr<Array>* array) { if (!json_array.IsObject()) { return Status::Invalid("Element was not a JSON object"); } const auto& json_obj = json_array.GetObject(); - const auto& json_name = json_obj.FindMember("name"); - RETURN_NOT_STRING("name", json_name, json_obj); + const auto& it_name = json_obj.FindMember("name"); + RETURN_NOT_STRING("name", it_name, json_obj); - std::string name = json_name->value.GetString(); + std::string name = it_name->value.GetString(); std::shared_ptr<Field> result = nullptr; for (const std::shared_ptr<Field>& field : schema.fields()) { @@ -1204,8 +1429,10 @@ Status ReadJsonArray(MemoryPool* pool, const rj::Value& json_array, const Schema return Status::KeyError(ss.str()); } - return ReadJsonArray(pool, json_array, result->type(), array); + return ReadArray(pool, json_array, result->type(), array); } +} // namespace internal +} // namespace json } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/json-internal.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json-internal.h b/cpp/src/arrow/ipc/json-internal.h index 0c167a4..5571d92 100644 --- a/cpp/src/arrow/ipc/json-internal.h +++ b/cpp/src/arrow/ipc/json-internal.h @@ -35,9 +35,11 @@ namespace rj = rapidjson; using RjWriter = rj::Writer<rj::StringBuffer>; +using RjArray = rj::Value::ConstArray; +using RjObject = rj::Value::ConstObject; #define RETURN_NOT_FOUND(TOK, NAME, PARENT) \ - if (NAME == PARENT.MemberEnd()) { \ + if (NAME == (PARENT).MemberEnd()) { \ std::stringstream ss; \ ss << "field " << TOK << " not found"; \ return 
Status::Invalid(ss.str()); \ @@ -90,21 +92,27 @@ using RjWriter = rj::Writer<rj::StringBuffer>; namespace arrow { namespace ipc { +namespace json { +namespace internal { -// TODO(wesm): Only exporting these because arrow_ipc does not have a static -// library at the moment. Better to not export -Status ARROW_EXPORT WriteJsonSchema(const Schema& schema, RjWriter* json_writer); -Status ARROW_EXPORT WriteJsonArray( - const std::string& name, const Array& array, RjWriter* json_writer); +Status WriteSchema(const Schema& schema, RjWriter* writer); +Status WriteRecordBatch(const RecordBatch& batch, RjWriter* writer); +Status WriteArray(const std::string& name, const Array& array, RjWriter* writer); -Status ARROW_EXPORT ReadJsonSchema( - const rj::Value& json_obj, std::shared_ptr<Schema>* schema); -Status ARROW_EXPORT ReadJsonArray(MemoryPool* pool, const rj::Value& json_obj, +Status ReadSchema( + const rj::Value& json_obj, MemoryPool* pool, std::shared_ptr<Schema>* schema); + +Status ReadRecordBatch(const rj::Value& json_obj, const std::shared_ptr<Schema>& schema, + MemoryPool* pool, std::shared_ptr<RecordBatch>* batch); + +Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, const std::shared_ptr<DataType>& type, std::shared_ptr<Array>* array); -Status ARROW_EXPORT ReadJsonArray(MemoryPool* pool, const rj::Value& json_obj, - const Schema& schema, std::shared_ptr<Array>* array); +Status ReadArray(MemoryPool* pool, const rj::Value& json_obj, const Schema& schema, + std::shared_ptr<Array>* array); +} // namespace internal +} // namespace json } // namespace ipc } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/json.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json.cc b/cpp/src/arrow/ipc/json.cc index 0abd6d7..f8c0b62 100644 --- a/cpp/src/arrow/ipc/json.cc +++ b/cpp/src/arrow/ipc/json.cc @@ -45,9 +45,7 @@ class JsonWriter::JsonWriterImpl { Status Start() { 
writer_->StartObject(); - - writer_->Key("schema"); - RETURN_NOT_OK(WriteJsonSchema(*schema_.get(), writer_.get())); + RETURN_NOT_OK(json::internal::WriteSchema(*schema_, writer_.get())); // Record batches writer_->Key("batches"); @@ -65,26 +63,7 @@ class JsonWriter::JsonWriterImpl { Status WriteRecordBatch(const RecordBatch& batch) { DCHECK_EQ(batch.num_columns(), schema_->num_fields()); - - writer_->StartObject(); - writer_->Key("count"); - writer_->Int(static_cast<int32_t>(batch.num_rows())); - - writer_->Key("columns"); - writer_->StartArray(); - - for (int i = 0; i < schema_->num_fields(); ++i) { - const std::shared_ptr<Array>& column = batch.column(i); - - DCHECK_EQ(batch.num_rows(), column->length()) - << "Array length did not match record batch length"; - - RETURN_NOT_OK(WriteJsonArray(schema_->field(i)->name(), *column, writer_.get())); - } - - writer_->EndArray(); - writer_->EndObject(); - return Status::OK(); + return json::internal::WriteRecordBatch(batch, writer_.get()); } private: @@ -127,11 +106,9 @@ class JsonReader::JsonReaderImpl { static_cast<size_t>(data_->size())); if (doc_.HasParseError()) { return Status::IOError("JSON parsing failed"); } - auto it = doc_.FindMember("schema"); - RETURN_NOT_OBJECT("schema", it, doc_); - RETURN_NOT_OK(ReadJsonSchema(it->value, &schema_)); + RETURN_NOT_OK(json::internal::ReadSchema(doc_, pool_, &schema_)); - it = doc_.FindMember("batches"); + auto it = doc_.FindMember("batches"); RETURN_NOT_ARRAY("batches", it, doc_); record_batches_ = &it->value; @@ -143,27 +120,8 @@ class JsonReader::JsonReaderImpl { DCHECK_LT(i, static_cast<int>(record_batches_->GetArray().Size())) << "i out of bounds"; - const auto& batch_val = record_batches_->GetArray()[i]; - DCHECK(batch_val.IsObject()); - - const auto& batch_obj = batch_val.GetObject(); - - auto it = batch_obj.FindMember("count"); - RETURN_NOT_INT("count", it, batch_obj); - int32_t num_rows = static_cast<int32_t>(it->value.GetInt()); - - it = 
batch_obj.FindMember("columns"); - RETURN_NOT_ARRAY("columns", it, batch_obj); - const auto& json_columns = it->value.GetArray(); - - std::vector<std::shared_ptr<Array>> columns(json_columns.Size()); - for (int i = 0; i < static_cast<int>(columns.size()); ++i) { - const std::shared_ptr<DataType>& type = schema_->field(i)->type(); - RETURN_NOT_OK(ReadJsonArray(pool_, json_columns[i], type, &columns[i])); - } - - *batch = std::make_shared<RecordBatch>(schema_, num_rows, columns); - return Status::OK(); + return json::internal::ReadRecordBatch( + record_batches_->GetArray()[i], schema_, pool_, batch); } std::shared_ptr<Schema> schema() const { return schema_; } @@ -178,7 +136,6 @@ class JsonReader::JsonReaderImpl { rj::Document doc_; const rj::Value* record_batches_; - std::shared_ptr<Schema> schema_; }; http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/json.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json.h b/cpp/src/arrow/ipc/json.h index 0d88cef..ad94def 100644 --- a/cpp/src/arrow/ipc/json.h +++ b/cpp/src/arrow/ipc/json.h @@ -44,10 +44,7 @@ class ARROW_EXPORT JsonWriter { static Status Open( const std::shared_ptr<Schema>& schema, std::unique_ptr<JsonWriter>* out); - // TODO(wesm): Write dictionaries - Status WriteRecordBatch(const RecordBatch& batch); - Status Finish(std::string* result); private: http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/metadata.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/metadata.h b/cpp/src/arrow/ipc/metadata.h index 84026c4..ec7bc39 100644 --- a/cpp/src/arrow/ipc/metadata.h +++ b/cpp/src/arrow/ipc/metadata.h @@ -77,6 +77,7 @@ class ARROW_EXPORT DictionaryMemo { // Returns KeyError if dictionary not found Status GetDictionary(int64_t id, std::shared_ptr<Array>* dictionary) const; + /// Return id for dictionary, computing new id if necessary int64_t GetId(const 
std::shared_ptr<Array>& dictionary); bool HasDictionary(const std::shared_ptr<Array>& dictionary) const; @@ -88,6 +89,8 @@ class ARROW_EXPORT DictionaryMemo { const DictionaryMap& id_to_dictionary() const { return id_to_dictionary_; } + int size() const { return static_cast<int>(id_to_dictionary_.size()); } + private: // Dictionary memory addresses, to track whether a dictionary has been seen // before http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/test-common.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/test-common.h b/cpp/src/arrow/ipc/test-common.h index 5caa3a9..deaeb59 100644 --- a/cpp/src/arrow/ipc/test-common.h +++ b/cpp/src/arrow/ipc/test-common.h @@ -55,10 +55,17 @@ static inline void CompareBatch(const RecordBatch& left, const RecordBatch& righ } ASSERT_EQ(left.num_columns(), right.num_columns()) << left.schema()->ToString() << " result: " << right.schema()->ToString(); - EXPECT_EQ(left.num_rows(), right.num_rows()); + ASSERT_EQ(left.num_rows(), right.num_rows()); for (int i = 0; i < left.num_columns(); ++i) { - EXPECT_TRUE(left.column(i)->Equals(right.column(i))) - << "Idx: " << i << " Name: " << left.column_name(i); + if (!left.column(i)->Equals(right.column(i))) { + std::stringstream ss; + ss << "Idx: " << i << " Name: " << left.column_name(i); + ss << std::endl << "Left: "; + ASSERT_OK(PrettyPrint(*left.column(i), 0, &ss)); + ss << std::endl << "Right: "; + ASSERT_OK(PrettyPrint(*right.column(i), 0, &ss)); + FAIL() << ss.str(); + } } } http://git-wip-us.apache.org/repos/asf/arrow/blob/25ba44c3/cpp/src/arrow/ipc/writer.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/writer.cc b/cpp/src/arrow/ipc/writer.cc index 5d4b94a..60b1f47 100644 --- a/cpp/src/arrow/ipc/writer.cc +++ b/cpp/src/arrow/ipc/writer.cc @@ -591,10 +591,7 @@ RecordBatchWriter::~RecordBatchWriter() {} class 
RecordBatchStreamWriter::RecordBatchStreamWriterImpl { public: RecordBatchStreamWriterImpl() - : dictionary_memo_(std::make_shared<DictionaryMemo>()), - pool_(default_memory_pool()), - position_(-1), - started_(false) {} + : pool_(default_memory_pool()), position_(-1), started_(false) {} virtual ~RecordBatchStreamWriterImpl() = default; @@ -606,7 +603,7 @@ class RecordBatchStreamWriter::RecordBatchStreamWriterImpl { virtual Status Start() { std::shared_ptr<Buffer> schema_fb; - RETURN_NOT_OK(WriteSchemaMessage(*schema_, dictionary_memo_.get(), &schema_fb)); + RETURN_NOT_OK(WriteSchemaMessage(*schema_, &dictionary_memo_, &schema_fb)); int32_t flatbuffer_size = static_cast<int32_t>(schema_fb->size()); RETURN_NOT_OK( @@ -640,7 +637,7 @@ class RecordBatchStreamWriter::RecordBatchStreamWriterImpl { Status UpdatePosition() { return sink_->Tell(&position_); } Status WriteDictionaries() { - const DictionaryMap& id_to_dictionary = dictionary_memo_->id_to_dictionary(); + const DictionaryMap& id_to_dictionary = dictionary_memo_.id_to_dictionary(); dictionaries_.resize(id_to_dictionary.size()); @@ -709,7 +706,7 @@ class RecordBatchStreamWriter::RecordBatchStreamWriterImpl { // When writing out the schema, we keep track of all the dictionaries we // encounter, as they must be written out first in the stream - std::shared_ptr<DictionaryMemo> dictionary_memo_; + DictionaryMemo dictionary_memo_; MemoryPool* pool_; @@ -770,7 +767,7 @@ class RecordBatchFileWriter::RecordBatchFileWriterImpl // Write metadata int64_t initial_position = position_; RETURN_NOT_OK(WriteFileFooter( - *schema_, dictionaries_, record_batches_, dictionary_memo_.get(), sink_)); + *schema_, dictionaries_, record_batches_, &dictionary_memo_, sink_)); RETURN_NOT_OK(UpdatePosition()); // Write footer length