lidavidm commented on a change in pull request #11446:
URL: https://github.com/apache/arrow/pull/11446#discussion_r741430380
##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -1113,5 +1113,539 @@ TEST(HashJoin, Random) {
}
}
+// Rewrites `batch` in place so that it can be compared independently of encoding:
+// every scalar becomes an array of batch->length copies, and every dictionary array
+// is replaced by a plain array of its decoded values (via Take on the dictionary).
+void DecodeScalarsAndDictionariesInBatch(ExecBatch* batch, MemoryPool* pool) {
+  for (Datum& value : batch->values) {
+    if (value.is_scalar()) {
+      ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> col,
+                           MakeArrayFromScalar(*value.scalar(), batch->length, pool));
+      value = Datum(col);
+    }
+    if (value.type()->id() == Type::DICTIONARY) {
+      std::shared_ptr<ArrayData> data = value.array();
+      const auto& dict_type = checked_cast<const DictionaryType&>(*data->type);
+      // View the dictionary array's buffers as a plain index array. Offset and null
+      // count must be carried over; otherwise a sliced input would be decoded
+      // starting from the wrong physical position.
+      std::shared_ptr<ArrayData> indices =
+          ArrayData::Make(dict_type.index_type(), data->length, data->buffers,
+                          data->null_count, data->offset);
+      ASSERT_OK_AND_ASSIGN(Datum col, Take(*data->dictionary, *indices));
+      value = col;
+    }
+  }
+}
+
+// Returns a new schema in which each dictionary-typed field is replaced by a
+// nullable field of the dictionary's value type; all other fields are copied.
+// Used to describe batches after DecodeScalarsAndDictionariesInBatch.
+std::shared_ptr<Schema> UpdateSchemaAfterDecodingDictionaries(
+    const std::shared_ptr<Schema>& schema) {
+  std::vector<std::shared_ptr<Field>> decoded_fields;
+  decoded_fields.reserve(schema->num_fields());
+  for (const std::shared_ptr<Field>& f : schema->fields()) {
+    if (f->type()->id() != Type::DICTIONARY) {
+      decoded_fields.push_back(f->Copy());
+      continue;
+    }
+    const auto& dict_type = checked_cast<const DictionaryType&>(*f->type());
+    decoded_fields.push_back(
+        std::make_shared<Field>(f->name(), dict_type.value_type(), true /* nullable */));
+  }
+  return std::make_shared<Schema>(std::move(decoded_fields));
+}
+
+void TestHashJoinDictionaryHelper(
+ JoinType join_type, JoinKeyCmp cmp,
+ // Whether to run parallel hash join.
+ // This requires generating multiple copies of each input batch on one
side of the
+ // join. Expected results will be automatically adjusted to reflect the
multiplication
+ // of input batches.
+ bool parallel, Datum l_key, Datum l_payload, Datum r_key, Datum r_payload,
+ Datum l_out_key, Datum l_out_payload, Datum r_out_key, Datum r_out_payload,
+ // Number of rows at the end of expected output that represent rows from
the right
+ // side that do not have a match on the left side. This number is needed to
+ // automatically adjust expected result when multiplying input batches on
the left
+ // side.
+ int expected_num_r_no_match,
+ // Whether to swap two inputs to the hash join
+ bool swap_sides) {
+ int64_t l_length = l_key.is_array()
+ ? l_key.array()->length
+ : l_payload.is_array() ? l_payload.array()->length :
-1;
+ int64_t r_length = r_key.is_array()
+ ? r_key.array()->length
+ : r_payload.is_array() ? r_payload.array()->length :
-1;
+ ARROW_DCHECK(l_length >= 0 && r_length >= 0);
+
+ constexpr int batch_multiplicity_for_parallel = 2;
+
+ // Split both sides into exactly two batches
+ int64_t l_first_length = l_length / 2;
+ int64_t r_first_length = r_length / 2;
+ BatchesWithSchema l_batches, r_batches;
+ l_batches.batches.resize(2);
+ r_batches.batches.resize(2);
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[0],
+ ExecBatch::Make({l_key.is_array() ? l_key.array()->Slice(0,
l_first_length) : l_key,
+ l_payload.is_array() ? l_payload.array()->Slice(0,
l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[1],
+ ExecBatch::Make(
+ {l_key.is_array()
+ ? l_key.array()->Slice(l_first_length, l_length -
l_first_length)
+ : l_key,
+ l_payload.is_array()
+ ? l_payload.array()->Slice(l_first_length, l_length -
l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[0],
+ ExecBatch::Make({r_key.is_array() ? r_key.array()->Slice(0,
r_first_length) : r_key,
+ r_payload.is_array() ? r_payload.array()->Slice(0,
r_first_length)
+ : r_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[1],
+ ExecBatch::Make(
+ {r_key.is_array()
+ ? r_key.array()->Slice(r_first_length, r_length -
r_first_length)
+ : r_key,
+ r_payload.is_array()
+ ? r_payload.array()->Slice(r_first_length, r_length -
r_first_length)
+ : r_payload}));
+ l_batches.schema =
+ schema({field("l_key", l_key.type()), field("l_payload",
l_payload.type())});
+ r_batches.schema =
+ schema({field("r_key", r_key.type()), field("r_payload",
r_payload.type())});
+
+ // Add copies of input batches on originally left side of the hash join
+ if (parallel) {
+ for (int i = 0; i < batch_multiplicity_for_parallel - 1; ++i) {
+ l_batches.batches.push_back(l_batches.batches[0]);
+ l_batches.batches.push_back(l_batches.batches[1]);
+ }
+ }
+
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() :
nullptr);
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
+
/*slow=*/false)}));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
+
/*slow=*/false)}));
+ HashJoinNodeOptions join_options{join_type,
+ {FieldRef(swap_sides ? "r_key" : "l_key")},
+ {FieldRef(swap_sides ? "l_key" : "r_key")},
+ {FieldRef(swap_sides ? "r_key" : "l_key"),
+ FieldRef(swap_sides ? "r_payload" :
"l_payload")},
+ {FieldRef(swap_sides ? "l_key" : "r_key"),
+ FieldRef(swap_sides ? "l_payload" :
"r_payload")},
+ {cmp}};
+ ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
+ {(swap_sides ? r_source :
l_source),
+ (swap_sides ? l_source :
r_source)},
+ join_options));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(
+ std::ignore, MakeExecNode("sink", plan.get(), {join},
SinkNodeOptions{&sink_gen}));
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(),
sink_gen));
+
+ for (auto& batch : res) {
+ DecodeScalarsAndDictionariesInBatch(&batch, exec_ctx->memory_pool());
+ }
+ std::shared_ptr<Schema> output_schema =
+ UpdateSchemaAfterDecodingDictionaries(join->output_schema());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> output,
+ TableFromExecBatches(output_schema, res));
+
+ ExecBatch expected_batch;
+ if (swap_sides) {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({r_out_key,
r_out_payload,
+ l_out_key,
l_out_payload}));
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({l_out_key,
l_out_payload,
+ r_out_key,
r_out_payload}));
+ }
+
+ DecodeScalarsAndDictionariesInBatch(&expected_batch,
exec_ctx->memory_pool());
+
+ // Slice expected batch into two to separate rows on right side with no
matches from
+ // everything else.
+ //
+ std::vector<ExecBatch> expected_batches;
+ ASSERT_OK_AND_ASSIGN(
+ auto prefix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ 0, expected_batch.length -
expected_num_r_no_match)}));
+ for (int i = 0; i < (parallel ? batch_multiplicity_for_parallel : 1); ++i) {
+ expected_batches.push_back(prefix_batch);
+ }
+ if (expected_num_r_no_match > 0) {
+ ASSERT_OK_AND_ASSIGN(
+ auto suffix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match)}));
+ expected_batches.push_back(suffix_batch);
+ }
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> expected,
+ TableFromExecBatches(output_schema, expected_batches));
+
+ // Compare results
+ AssertTablesEqual(expected, output);
+
+ // TODO: This was added for debugging. Remove in the final version.
+ // std::cout << output->ToString();
Review comment:
nit: remove this TODO?
##########
File path: cpp/src/arrow/compute/exec/hash_join_dict.h
##########
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+// This file contains hash join logic related to handling of dictionary
encoded key
+// columns.
+//
+// A key column from probe side of the join can be matched against a key
column from build
+// side of the join, as long as the underlying value types are equal. That
means that:
+// - both scalars and arrays can be used and even mixed in the same column
+// - dictionary column can be matched against non-dictionary column if
underlying value
+// types are equal
+// - dictionary column can be matched against dictionary column with a
different index
+// type, and potentially using a different dictionary, if underlying value
types are equal
+//
+// We currently require in hash join that, for all dictionary encoded columns, the
+// same dictionary is used in all input exec batches.
+//
+// In order to allow matching columns with different dictionaries, different
dictionary
+// index types, and dictionary key against non-dictionary key, internally
comparisons will
+// be evaluated after remapping values on both sides of the join to a common
+// representation (which will be called "unified representation"). This common
+// representation is a column of int32() type (not a dictionary column). It
represents an
+// index in the unified dictionary computed for the (only) dictionary present
on build
+// side (an empty dictionary is still created for an empty build side). A null value
+// is always represented in this common representation as a null int32 value; the
+// unified dictionary never contains a null entry, so there is no ambiguity between
+// representing a null as an index to a null dictionary entry or as a null index.
+//
+// The unified dictionary represents values present on the build side. The probe side
+// may contain values that are not present in it; all such non-null values are mapped
+// in the common representation to the special constant kMissingValueId.
+//
+
+namespace arrow {
+namespace compute {
+
+using internal::RowEncoder;
+
+/// Helper class with operations that are stateless and common to processing of
+/// dictionary keys on both build and probe side.
+///
+/// Naming note: the "Cvt" prefix (CvtToInt32, CvtFromInt32, CvtImp) abbreviates
+/// "convert"; these functions cast integer index data to or from int32().
+class HashJoinDictUtil {
+ public:
+ // Null values in unified representation are always represented as null that has
+ // corresponding integer set to this constant.
+ static constexpr int32_t kNullId = 0;
+ // Constant representing, in unified representation, a non-null value that is
+ // missing on the build side.
+ static constexpr int32_t kMissingValueId = -1;
+
+ // Check if data types of corresponding pair of key column on build and probe side
+ // are compatible.
+ static bool KeyDataTypesValid(const std::shared_ptr<DataType>&
probe_data_type,
+ const std::shared_ptr<DataType>&
build_data_type);
+
+ // Input must be dictionary array or dictionary scalar.
+ // A precomputed lookup table, provided here in the form of an int32() array, is
+ // used to remap input indices to unified representation.
+ //
+ static Result<std::shared_ptr<ArrayData>> IndexRemapUsingLUT(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<ArrayData>& map_array,
+ const std::shared_ptr<DataType>& data_type);
+
+ // Return int32() array that contains indices of input dictionary array or scalar
+ // after type casting.
+ static Result<std::shared_ptr<ArrayData>> CvtToInt32(
+ const std::shared_ptr<DataType>& from_type, const Datum& input,
+ int64_t batch_length, ExecContext* ctx);
+
+ // Return an array that contains elements of input int32() array after casting to a
+ // given integer type. This is used for mapping unified representation stored in
+ // the hash table on build side back to original input data type of hash join,
+ // when outputting hash join results to parent exec node.
+ //
+ static Result<std::shared_ptr<ArrayData>> CvtFromInt32(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t
batch_length,
+ ExecContext* ctx);
+
+ // Return dictionary referenced in either dictionary array or dictionary scalar.
+ static std::shared_ptr<Array> ExtractDictionary(const Datum& data);
+
+ private:
+ // NOTE(review): presumably the shared templated implementation behind
+ // CvtToInt32 / CvtFromInt32 (casting FROM -> TO integer types); confirm against
+ // the .cc. Being private, it could likely become a free function local to the .cc
+ // file instead of being declared in this header.
+ template <typename FROM, typename TO>
+ static Result<std::shared_ptr<ArrayData>> CvtImp(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t
batch_length,
+ ExecContext* ctx);
+};
+
+/// Implements processing of dictionary arrays/scalars in key columns on the
build side of
+/// a hash join.
+/// Each instance of this class corresponds to a single column and stores and
+/// processes only the information related to that column.
+/// Const methods are thread-safe, non-const methods are not (the caller must
make sure
+/// that only one thread at any time will access them).
+///
Review comment:
FWIW, thanks for the detailed comments in this file - they help a lot in
understanding what's going on here.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -1113,5 +1113,539 @@ TEST(HashJoin, Random) {
}
}
+// Rewrites `batch` in place so that it can be compared independently of encoding:
+// every scalar becomes an array of batch->length copies, and every dictionary array
+// is replaced by a plain array of its decoded values (via Take on the dictionary).
+void DecodeScalarsAndDictionariesInBatch(ExecBatch* batch, MemoryPool* pool) {
+  for (Datum& value : batch->values) {
+    if (value.is_scalar()) {
+      ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> col,
+                           MakeArrayFromScalar(*value.scalar(), batch->length, pool));
+      value = Datum(col);
+    }
+    if (value.type()->id() == Type::DICTIONARY) {
+      std::shared_ptr<ArrayData> data = value.array();
+      const auto& dict_type = checked_cast<const DictionaryType&>(*data->type);
+      // View the dictionary array's buffers as a plain index array. Offset and null
+      // count must be carried over; otherwise a sliced input would be decoded
+      // starting from the wrong physical position.
+      std::shared_ptr<ArrayData> indices =
+          ArrayData::Make(dict_type.index_type(), data->length, data->buffers,
+                          data->null_count, data->offset);
+      ASSERT_OK_AND_ASSIGN(Datum col, Take(*data->dictionary, *indices));
+      value = col;
+    }
+  }
+}
+
+// Returns a new schema in which each dictionary-typed field is replaced by a
+// nullable field of the dictionary's value type; all other fields are copied.
+// Used to describe batches after DecodeScalarsAndDictionariesInBatch.
+std::shared_ptr<Schema> UpdateSchemaAfterDecodingDictionaries(
+    const std::shared_ptr<Schema>& schema) {
+  std::vector<std::shared_ptr<Field>> decoded_fields;
+  decoded_fields.reserve(schema->num_fields());
+  for (const std::shared_ptr<Field>& f : schema->fields()) {
+    if (f->type()->id() != Type::DICTIONARY) {
+      decoded_fields.push_back(f->Copy());
+      continue;
+    }
+    const auto& dict_type = checked_cast<const DictionaryType&>(*f->type());
+    decoded_fields.push_back(
+        std::make_shared<Field>(f->name(), dict_type.value_type(), true /* nullable */));
+  }
+  return std::make_shared<Schema>(std::move(decoded_fields));
+}
+
+void TestHashJoinDictionaryHelper(
+    JoinType join_type, JoinKeyCmp cmp,
+    // Whether to run parallel hash join.
+    // This requires generating multiple copies of each input batch on one side of the
+    // join. Expected results will be automatically adjusted to reflect the
+    // multiplication of input batches.
+    bool parallel, Datum l_key, Datum l_payload, Datum r_key, Datum r_payload,
+    Datum l_out_key, Datum l_out_payload, Datum r_out_key, Datum r_out_payload,
+    // Number of rows at the end of expected output that represent rows from the right
+    // side that do not have a match on the left side. This number is needed to
+    // automatically adjust expected result when multiplying input batches on the left
+    // side.
+    int expected_num_r_no_match,
+    // Whether to swap two inputs to the hash join
+    bool swap_sides) {
+  // Number of rows on one side of the join, taken from whichever of its two columns
+  // is an array; -1 when both columns are scalars (not supported by this helper).
+  auto side_length = [](const Datum& key, const Datum& payload) -> int64_t {
+    if (key.is_array()) return key.array()->length;
+    if (payload.is_array()) return payload.array()->length;
+    return -1;
+  };
+  // Slices an input column; scalar columns are passed through unchanged.
+  auto slice = [](const Datum& column, int64_t offset, int64_t length) -> Datum {
+    return column.is_array() ? Datum(column.array()->Slice(offset, length)) : column;
+  };
+
+  const int64_t l_length = side_length(l_key, l_payload);
+  const int64_t r_length = side_length(r_key, r_payload);
+  // gtest assertions instead of ARROW_DCHECK: a debug check is compiled out in
+  // release builds, where an unsupported all-scalar side would otherwise continue
+  // silently.
+  ASSERT_GE(l_length, 0);
+  ASSERT_GE(r_length, 0);
+
+  constexpr int batch_multiplicity_for_parallel = 2;
+
+  // Split both sides into exactly two batches
+  const int64_t l_first_length = l_length / 2;
+  const int64_t r_first_length = r_length / 2;
+  BatchesWithSchema l_batches, r_batches;
+  l_batches.batches.resize(2);
+  r_batches.batches.resize(2);
+  ASSERT_OK_AND_ASSIGN(l_batches.batches[0],
+                       ExecBatch::Make({slice(l_key, 0, l_first_length),
+                                        slice(l_payload, 0, l_first_length)}));
+  ASSERT_OK_AND_ASSIGN(
+      l_batches.batches[1],
+      ExecBatch::Make({slice(l_key, l_first_length, l_length - l_first_length),
+                       slice(l_payload, l_first_length, l_length - l_first_length)}));
+  ASSERT_OK_AND_ASSIGN(r_batches.batches[0],
+                       ExecBatch::Make({slice(r_key, 0, r_first_length),
+                                        slice(r_payload, 0, r_first_length)}));
+  ASSERT_OK_AND_ASSIGN(
+      r_batches.batches[1],
+      ExecBatch::Make({slice(r_key, r_first_length, r_length - r_first_length),
+                       slice(r_payload, r_first_length, r_length - r_first_length)}));
+  l_batches.schema =
+      schema({field("l_key", l_key.type()), field("l_payload", l_payload.type())});
+  r_batches.schema =
+      schema({field("r_key", r_key.type()), field("r_payload", r_payload.type())});
+
+  // Add copies of input batches on originally left side of the hash join
+  if (parallel) {
+    for (int i = 0; i < batch_multiplicity_for_parallel - 1; ++i) {
+      l_batches.batches.push_back(l_batches.batches[0]);
+      l_batches.batches.push_back(l_batches.batches[1]);
+    }
+  }
+
+  auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+      default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+  ASSERT_OK_AND_ASSIGN(
+      ExecNode * l_source,
+      MakeExecNode("source", plan.get(), {},
+                   SourceNodeOptions{l_batches.schema,
+                                     l_batches.gen(parallel, /*slow=*/false)}));
+  ASSERT_OK_AND_ASSIGN(
+      ExecNode * r_source,
+      MakeExecNode("source", plan.get(), {},
+                   SourceNodeOptions{r_batches.schema,
+                                     r_batches.gen(parallel, /*slow=*/false)}));
+  HashJoinNodeOptions join_options{join_type,
+                                   {FieldRef(swap_sides ? "r_key" : "l_key")},
+                                   {FieldRef(swap_sides ? "l_key" : "r_key")},
+                                   {FieldRef(swap_sides ? "r_key" : "l_key"),
+                                    FieldRef(swap_sides ? "r_payload" : "l_payload")},
+                                   {FieldRef(swap_sides ? "l_key" : "r_key"),
+                                    FieldRef(swap_sides ? "l_payload" : "r_payload")},
+                                   {cmp}};
+  ASSERT_OK_AND_ASSIGN(
+      ExecNode * join,
+      MakeExecNode("hashjoin", plan.get(),
+                   {(swap_sides ? r_source : l_source), (swap_sides ? l_source : r_source)},
+                   join_options));
+  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  ASSERT_OK_AND_ASSIGN(
+      std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen}));
+  ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+
+  for (auto& batch : res) {
+    DecodeScalarsAndDictionariesInBatch(&batch, exec_ctx->memory_pool());
+  }
+  std::shared_ptr<Schema> output_schema =
+      UpdateSchemaAfterDecodingDictionaries(join->output_schema());
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> output,
+                       TableFromExecBatches(output_schema, res));
+
+  ExecBatch expected_batch;
+  if (swap_sides) {
+    ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({r_out_key, r_out_payload,
+                                                          l_out_key, l_out_payload}));
+  } else {
+    ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({l_out_key, l_out_payload,
+                                                          r_out_key, r_out_payload}));
+  }
+
+  DecodeScalarsAndDictionariesInBatch(&expected_batch, exec_ctx->memory_pool());
+
+  // Slice expected batch into two to separate rows on right side with no matches from
+  // everything else. Only the first part is repeated in parallel mode, because the
+  // extra input copies are added on the left side only.
+  auto slice_expected = [&expected_batch](int64_t offset,
+                                          int64_t length) -> Result<ExecBatch> {
+    std::vector<Datum> columns;
+    columns.reserve(expected_batch.values.size());
+    for (const Datum& column : expected_batch.values) {
+      columns.emplace_back(column.array()->Slice(offset, length));
+    }
+    return ExecBatch::Make(std::move(columns));
+  };
+  const int64_t num_matched_rows = expected_batch.length - expected_num_r_no_match;
+  std::vector<ExecBatch> expected_batches;
+  ASSERT_OK_AND_ASSIGN(ExecBatch prefix_batch, slice_expected(0, num_matched_rows));
+  for (int i = 0; i < (parallel ? batch_multiplicity_for_parallel : 1); ++i) {
+    expected_batches.push_back(prefix_batch);
+  }
+  if (expected_num_r_no_match > 0) {
+    ASSERT_OK_AND_ASSIGN(ExecBatch suffix_batch,
+                         slice_expected(num_matched_rows, expected_num_r_no_match));
+    expected_batches.push_back(suffix_batch);
+  }
+
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> expected,
+                       TableFromExecBatches(output_schema, expected_batches));
+
+  // Compare results
+  AssertTablesEqual(expected, output);
+}
+
+TEST(HashJoin, Dictionary) {
+ auto int8_utf8 = std::make_shared<DictionaryType>(int8(), utf8());
Review comment:
n.b. arrow::dictionary(index_ty, value_ty) is shorthand for this
##########
File path: cpp/src/arrow/compute/exec/hash_join_dict.h
##########
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+// This file contains hash join logic related to handling of dictionary
encoded key
+// columns.
+//
+// A key column from probe side of the join can be matched against a key
column from build
+// side of the join, as long as the underlying value types are equal. That
means that:
+// - both scalars and arrays can be used and even mixed in the same column
+// - dictionary column can be matched against non-dictionary column if
underlying value
+// types are equal
+// - dictionary column can be matched against dictionary column with a
different index
+// type, and potentially using a different dictionary, if underlying value
types are equal
+//
+// We currently require in hash join that, for all dictionary encoded columns, the
+// same dictionary is used in all input exec batches.
+//
+// In order to allow matching columns with different dictionaries, different
dictionary
+// index types, and dictionary key against non-dictionary key, internally
comparisons will
+// be evaluated after remapping values on both sides of the join to a common
+// representation (which will be called "unified representation"). This common
+// representation is a column of int32() type (not a dictionary column). It
represents an
+// index in the unified dictionary computed for the (only) dictionary present
on build
+// side (an empty dictionary is still created for an empty build side). A null value
+// is always represented in this common representation as a null int32 value; the
+// unified dictionary never contains a null entry, so there is no ambiguity between
+// representing a null as an index to a null dictionary entry or as a null index.
+//
+// The unified dictionary represents values present on the build side. The probe side
+// may contain values that are not present in it; all such non-null values are mapped
+// in the common representation to the special constant kMissingValueId.
+//
+
+namespace arrow {
+namespace compute {
+
+using internal::RowEncoder;
+
+/// Helper class with operations that are stateless and common to processing
of dictionary
+/// keys on both build and probe side.
+class HashJoinDictUtil {
+ public:
+ // Null values in unified representation are always represented as null that
has
+ // corresponding integer set to this constant
+ static constexpr int32_t kNullId = 0;
+ // Constant representing a value, that is not null, missing on the build
side, in
+ // unified representation.
+ static constexpr int32_t kMissingValueId = -1;
+
+ // Check if data types of corresponding pair of key column on build and
probe side are
+ // compatible
+ static bool KeyDataTypesValid(const std::shared_ptr<DataType>&
probe_data_type,
+ const std::shared_ptr<DataType>&
build_data_type);
+
+ // Input must be dictionary array or dictionary scalar.
+ // A precomputed and provided here lookup table in the form of int32() array
will be
+ // used to remap input indices to unified representation.
+ //
+ static Result<std::shared_ptr<ArrayData>> IndexRemapUsingLUT(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<ArrayData>& map_array,
+ const std::shared_ptr<DataType>& data_type);
+
+ // Return int32() array that contains indices of input dictionary array or
scalar after
+ // type casting.
+ static Result<std::shared_ptr<ArrayData>> CvtToInt32(
Review comment:
I might be dense but what does "Cvt" stand for?
##########
File path: cpp/src/arrow/compute/exec/hash_join_dict.h
##########
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <unordered_map>
+
+#include "arrow/compute/exec.h"
+#include "arrow/compute/exec/schema_util.h"
+#include "arrow/compute/kernels/row_encoder.h"
+#include "arrow/result.h"
+#include "arrow/status.h"
+#include "arrow/type.h"
+
+// This file contains hash join logic related to handling of dictionary
encoded key
+// columns.
+//
+// A key column from probe side of the join can be matched against a key
column from build
+// side of the join, as long as the underlying value types are equal. That
means that:
+// - both scalars and arrays can be used and even mixed in the same column
+// - dictionary column can be matched against non-dictionary column if
underlying value
+// types are equal
+// - dictionary column can be matched against dictionary column with a
different index
+// type, and potentially using a different dictionary, if underlying value
types are equal
+//
+// We currently require in hash join that, for all dictionary encoded columns, the
+// same dictionary is used in all input exec batches.
+//
+// In order to allow matching columns with different dictionaries, different
dictionary
+// index types, and dictionary key against non-dictionary key, internally
comparisons will
+// be evaluated after remapping values on both sides of the join to a common
+// representation (which will be called "unified representation"). This common
+// representation is a column of int32() type (not a dictionary column). It
represents an
+// index in the unified dictionary computed for the (only) dictionary present
on build
+// side (an empty dictionary is still created for an empty build side). A null value
+// is always represented in this common representation as a null int32 value; the
+// unified dictionary never contains a null entry, so there is no ambiguity between
+// representing a null as an index to a null dictionary entry or as a null index.
+//
+// The unified dictionary represents values present on the build side. The probe side
+// may contain values that are not present in it; all such non-null values are mapped
+// in the common representation to the special constant kMissingValueId.
+//
+
+namespace arrow {
+namespace compute {
+
+using internal::RowEncoder;
+
+/// Helper class with operations that are stateless and common to processing
of dictionary
+/// keys on both build and probe side.
+class HashJoinDictUtil {
+ public:
+ // Null values in unified representation are always represented as null that
has
+ // corresponding integer set to this constant
+ static constexpr int32_t kNullId = 0;
+ // Constant representing a value, that is not null, missing on the build
side, in
+ // unified representation.
+ static constexpr int32_t kMissingValueId = -1;
+
+ // Check if data types of corresponding pair of key column on build and
probe side are
+ // compatible
+ static bool KeyDataTypesValid(const std::shared_ptr<DataType>&
probe_data_type,
+ const std::shared_ptr<DataType>&
build_data_type);
+
+ // Input must be dictionary array or dictionary scalar.
+ // A precomputed and provided here lookup table in the form of int32() array
will be
+ // used to remap input indices to unified representation.
+ //
+ static Result<std::shared_ptr<ArrayData>> IndexRemapUsingLUT(
+ ExecContext* ctx, const Datum& indices, int64_t batch_length,
+ const std::shared_ptr<ArrayData>& map_array,
+ const std::shared_ptr<DataType>& data_type);
+
+ // Return int32() array that contains indices of input dictionary array or
scalar after
+ // type casting.
+ static Result<std::shared_ptr<ArrayData>> CvtToInt32(
+ const std::shared_ptr<DataType>& from_type, const Datum& input,
+ int64_t batch_length, ExecContext* ctx);
+
+ // Return an array that contains elements of input int32() array after
casting to a
+ // given integer type. This is used for mapping unified representation
stored in the
+ // hash table on build side back to original input data type of hash join,
when
+ // outputting hash join results to parent exec node.
+ //
+ static Result<std::shared_ptr<ArrayData>> CvtFromInt32(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t
batch_length,
+ ExecContext* ctx);
+
+ // Return dictionary referenced in either dictionary array or dictionary
scalar
+ static std::shared_ptr<Array> ExtractDictionary(const Datum& data);
+
+ private:
+ template <typename FROM, typename TO>
+ static Result<std::shared_ptr<ArrayData>> CvtImp(
+ const std::shared_ptr<DataType>& to_type, const Datum& input, int64_t
batch_length,
+ ExecContext* ctx);
Review comment:
Does this need to be in the header at all then? It seems it could live
solely in the .cc file.
##########
File path: cpp/src/arrow/compute/exec/hash_join_node_test.cc
##########
@@ -1113,5 +1113,539 @@ TEST(HashJoin, Random) {
}
}
+void DecodeScalarsAndDictionariesInBatch(ExecBatch* batch, MemoryPool* pool) {
+ for (size_t i = 0; i < batch->values.size(); ++i) {
+ if (batch->values[i].is_scalar()) {
+ ASSERT_OK_AND_ASSIGN(
+ std::shared_ptr<Array> col,
+ MakeArrayFromScalar(*(batch->values[i].scalar()), batch->length,
pool));
+ batch->values[i] = Datum(col);
+ }
+ if (batch->values[i].type()->id() == Type::DICTIONARY) {
+ const auto& dict_type =
+ checked_cast<const DictionaryType&>(*batch->values[i].type());
+ std::shared_ptr<ArrayData> indices =
+ ArrayData::Make(dict_type.index_type(),
batch->values[i].array()->length,
+ batch->values[i].array()->buffers);
+ const std::shared_ptr<ArrayData>& dictionary =
batch->values[i].array()->dictionary;
+ ASSERT_OK_AND_ASSIGN(Datum col, Take(*dictionary, *indices));
+ batch->values[i] = col;
+ }
+ }
+}
+
+std::shared_ptr<Schema> UpdateSchemaAfterDecodingDictionaries(
+ const std::shared_ptr<Schema>& schema) {
+ std::vector<std::shared_ptr<Field>> output_fields(schema->num_fields());
+ for (int i = 0; i < schema->num_fields(); ++i) {
+ const std::shared_ptr<Field>& field = schema->field(i);
+ if (field->type()->id() == Type::DICTIONARY) {
+ const auto& dict_type = checked_cast<const
DictionaryType&>(*field->type());
+ output_fields[i] = std::make_shared<Field>(field->name(),
dict_type.value_type(),
+ true /* nullable */);
+ } else {
+ output_fields[i] = field->Copy();
+ }
+ }
+ return std::make_shared<Schema>(std::move(output_fields));
+}
+
+void TestHashJoinDictionaryHelper(
+ JoinType join_type, JoinKeyCmp cmp,
+ // Whether to run parallel hash join.
+ // This requires generating multiple copies of each input batch on one
side of the
+ // join. Expected results will be automatically adjusted to reflect the
multiplication
+ // of input batches.
+ bool parallel, Datum l_key, Datum l_payload, Datum r_key, Datum r_payload,
+ Datum l_out_key, Datum l_out_payload, Datum r_out_key, Datum r_out_payload,
+ // Number of rows at the end of expected output that represent rows from
the right
+ // side that do not have a match on the left side. This number is needed to
+ // automatically adjust expected result when multiplying input batches on
the left
+ // side.
+ int expected_num_r_no_match,
+ // Whether to swap two inputs to the hash join
+ bool swap_sides) {
+ int64_t l_length = l_key.is_array()
+ ? l_key.array()->length
+ : l_payload.is_array() ? l_payload.array()->length :
-1;
+ int64_t r_length = r_key.is_array()
+ ? r_key.array()->length
+ : r_payload.is_array() ? r_payload.array()->length :
-1;
+ ARROW_DCHECK(l_length >= 0 && r_length >= 0);
+
+ constexpr int batch_multiplicity_for_parallel = 2;
+
+ // Split both sides into exactly two batches
+ int64_t l_first_length = l_length / 2;
+ int64_t r_first_length = r_length / 2;
+ BatchesWithSchema l_batches, r_batches;
+ l_batches.batches.resize(2);
+ r_batches.batches.resize(2);
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[0],
+ ExecBatch::Make({l_key.is_array() ? l_key.array()->Slice(0,
l_first_length) : l_key,
+ l_payload.is_array() ? l_payload.array()->Slice(0,
l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ l_batches.batches[1],
+ ExecBatch::Make(
+ {l_key.is_array()
+ ? l_key.array()->Slice(l_first_length, l_length -
l_first_length)
+ : l_key,
+ l_payload.is_array()
+ ? l_payload.array()->Slice(l_first_length, l_length -
l_first_length)
+ : l_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[0],
+ ExecBatch::Make({r_key.is_array() ? r_key.array()->Slice(0,
r_first_length) : r_key,
+ r_payload.is_array() ? r_payload.array()->Slice(0,
r_first_length)
+ : r_payload}));
+ ASSERT_OK_AND_ASSIGN(
+ r_batches.batches[1],
+ ExecBatch::Make(
+ {r_key.is_array()
+ ? r_key.array()->Slice(r_first_length, r_length -
r_first_length)
+ : r_key,
+ r_payload.is_array()
+ ? r_payload.array()->Slice(r_first_length, r_length -
r_first_length)
+ : r_payload}));
+ l_batches.schema =
+ schema({field("l_key", l_key.type()), field("l_payload",
l_payload.type())});
+ r_batches.schema =
+ schema({field("r_key", r_key.type()), field("r_payload",
r_payload.type())});
+
+ // Add copies of input batches on originally left side of the hash join
+ if (parallel) {
+ for (int i = 0; i < batch_multiplicity_for_parallel - 1; ++i) {
+ l_batches.batches.push_back(l_batches.batches[0]);
+ l_batches.batches.push_back(l_batches.batches[1]);
+ }
+ }
+
+ auto exec_ctx = arrow::internal::make_unique<ExecContext>(
+ default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() :
nullptr);
+ ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * l_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{l_batches.schema, l_batches.gen(parallel,
+
/*slow=*/false)}));
+ ASSERT_OK_AND_ASSIGN(
+ ExecNode * r_source,
+ MakeExecNode("source", plan.get(), {},
+ SourceNodeOptions{r_batches.schema, r_batches.gen(parallel,
+
/*slow=*/false)}));
+ HashJoinNodeOptions join_options{join_type,
+ {FieldRef(swap_sides ? "r_key" : "l_key")},
+ {FieldRef(swap_sides ? "l_key" : "r_key")},
+ {FieldRef(swap_sides ? "r_key" : "l_key"),
+ FieldRef(swap_sides ? "r_payload" :
"l_payload")},
+ {FieldRef(swap_sides ? "l_key" : "r_key"),
+ FieldRef(swap_sides ? "l_payload" :
"r_payload")},
+ {cmp}};
+ ASSERT_OK_AND_ASSIGN(ExecNode * join, MakeExecNode("hashjoin", plan.get(),
+ {(swap_sides ? r_source :
l_source),
+ (swap_sides ? l_source :
r_source)},
+ join_options));
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK_AND_ASSIGN(
+ std::ignore, MakeExecNode("sink", plan.get(), {join},
SinkNodeOptions{&sink_gen}));
+ ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(),
sink_gen));
+
+ for (auto& batch : res) {
+ DecodeScalarsAndDictionariesInBatch(&batch, exec_ctx->memory_pool());
+ }
+ std::shared_ptr<Schema> output_schema =
+ UpdateSchemaAfterDecodingDictionaries(join->output_schema());
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> output,
+ TableFromExecBatches(output_schema, res));
+
+ ExecBatch expected_batch;
+ if (swap_sides) {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({r_out_key,
r_out_payload,
+ l_out_key,
l_out_payload}));
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected_batch, ExecBatch::Make({l_out_key,
l_out_payload,
+ r_out_key,
r_out_payload}));
+ }
+
+ DecodeScalarsAndDictionariesInBatch(&expected_batch,
exec_ctx->memory_pool());
+
+ // Slice expected batch into two to separate rows on right side with no
matches from
+ // everything else.
+ //
+ std::vector<ExecBatch> expected_batches;
+ ASSERT_OK_AND_ASSIGN(
+ auto prefix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ 0, expected_batch.length - expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ 0, expected_batch.length -
expected_num_r_no_match)}));
+ for (int i = 0; i < (parallel ? batch_multiplicity_for_parallel : 1); ++i) {
+ expected_batches.push_back(prefix_batch);
+ }
+ if (expected_num_r_no_match > 0) {
+ ASSERT_OK_AND_ASSIGN(
+ auto suffix_batch,
+ ExecBatch::Make({expected_batch.values[0].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[1].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[2].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match),
+ expected_batch.values[3].array()->Slice(
+ expected_batch.length - expected_num_r_no_match,
+ expected_num_r_no_match)}));
+ expected_batches.push_back(suffix_batch);
+ }
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> expected,
+ TableFromExecBatches(output_schema, expected_batches));
+
+ // Compare results
+ AssertTablesEqual(expected, output);
+
+ // TODO: This was added for debugging. Remove in the final version.
+ // std::cout << output->ToString();
+}
+
+TEST(HashJoin, Dictionary) {
+ auto int8_utf8 = std::make_shared<DictionaryType>(int8(), utf8());
+ auto uint8_utf8 = std::make_shared<DictionaryType>(uint8(), utf8());
+ auto int16_utf8 = std::make_shared<DictionaryType>(int16(), utf8());
+ auto uint16_utf8 = std::make_shared<DictionaryType>(uint16(), utf8());
+ auto int32_utf8 = std::make_shared<DictionaryType>(int32(), utf8());
+ auto uint32_utf8 = std::make_shared<DictionaryType>(uint32(), utf8());
+ auto int64_utf8 = std::make_shared<DictionaryType>(int64(), utf8());
+ auto uint64_utf8 = std::make_shared<DictionaryType>(uint64(), utf8());
+ std::shared_ptr<DictionaryType> dict_types[] = {int8_utf8, uint8_utf8,
int16_utf8,
+ uint16_utf8, int32_utf8,
uint32_utf8,
+ int64_utf8, uint64_utf8};
+
+ Random64Bit rng(43);
+
+ // Dictionaries in payload columns
+ for (auto parallel : {false, true})
Review comment:
nit: can we insert braces here for readability?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]