rtpsw commented on code in PR #13880:
URL: https://github.com/apache/arrow/pull/13880#discussion_r962021188
##########
cpp/src/arrow/compute/exec/asof_join_node_test.cc:
##########
@@ -74,237 +241,784 @@ void CheckRunOutput(const BatchesWithSchema& l_batches,
/*same_chunk_layout=*/true, /*flatten=*/true);
}
-void DoRunBasicTest(const std::vector<util::string_view>& l_data,
- const std::vector<util::string_view>& r0_data,
- const std::vector<util::string_view>& r1_data,
- const std::vector<util::string_view>& exp_data, int64_t
tolerance) {
- auto l_schema =
- schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())});
- auto r0_schema =
- schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())});
- auto r1_schema =
- schema({field("time", int64()), field("key", int32()), field("r1_v0",
float32())});
-
- auto exp_schema = schema({
- field("time", int64()),
- field("key", int32()),
- field("l_v0", float64()),
- field("r0_v0", float64()),
- field("r1_v0", float32()),
- });
-
- // Test three table join
- BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches;
- l_batches = MakeBatchesFromString(l_schema, l_data);
- r0_batches = MakeBatchesFromString(r0_schema, r0_data);
- r1_batches = MakeBatchesFromString(r1_schema, r1_data);
- exp_batches = MakeBatchesFromString(exp_schema, exp_data);
- CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key",
- tolerance);
-}
// Defines a CheckRunOutput overload that takes the on-key, a single by-key of type
// `by_key_type` and a tolerance, and forwards them (with the by-key passed as
// `{key}`) to the CheckRunOutput overload taking an AsofJoinNodeOptions.
// The on-key is taken by const reference so the FieldRef is not copied per call.
#define CHECK_RUN_OUTPUT(by_key_type)                                            \
  void CheckRunOutput(                                                           \
      const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches,   \
      const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \
      const FieldRef& time, by_key_type key, const int64_t tolerance) {          \
    CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches,               \
                   AsofJoinNodeOptions(time, {key}, tolerance));                 \
  }
-void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
- const std::shared_ptr<Schema>& r_schema) {
- BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"});
- BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"});
+EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT)
+void DoInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str,
+ bool fail_on_plan_creation = false) {
ExecContext exec_ctx;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
- AsofJoinNodeOptions join_options("time", "key", 0);
Declaration join{"asofjoin", join_options};
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{l_batches.schema, l_batches.gen(false,
false)}});
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{r_batches.schema, r_batches.gen(false,
false)}});
- ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+ if (fail_on_plan_creation) {
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK(Declaration::Sequence({join, {"sink",
SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid,
+
::testing::HasSubstr(expected_error_str),
+ StartAndCollect(plan.get(),
sink_gen));
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr(expected_error_str),
+ join.AddToPlan(plan.get()));
+ }
+}
+
+void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema,
{R"([])"}));
+ ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema,
{R"([])"}));
+
+ return DoRunInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema, int64_t
tolerance,
+ const std::string& expected_error_str) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"key"}, tolerance),
+ expected_error_str);
+}
+
+void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for ");
+}
+
+void DoRunInvalidToleranceTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, -1,
+ "AsOfJoin tolerance must be non-negative but is ");
}
+void DoRunMissingKeysTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No
match");
+}
+
+void DoRunMissingOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("invalid_time", {"key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunMissingByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"invalid_key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"},
{"key"}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema,
+ AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunAmbiguousOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table :
Multiple matches");
+}
+
+void DoRunAmbiguousByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table :
Multiple matches");
+}
+
// Gets a batch for testing as a JSON string.
// The batch has n_rows rows and n_cols columns. The first column is the on-field
// and holds the row index; every other column j holds the constant j.
// If unordered is true, the first column holds (row index XOR 1) instead, i.e.
// 1, 0, 3, 2, ..., so that it is out of order.
// NOTE(review): every non-first column (including a by-key column, if any) has the
// same value in all rows — confirm that is what the callers intend.
std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = false) {
  const int order_mask = unordered ? 1 : 0;
  std::stringstream s;
  s << '[';
  for (int i = 0; i < n_rows; i++) {
    if (i > 0) {
      s << ", ";
    }
    s << '[';
    if (n_cols > 0) {
      // Only two cases exist: column 0 (the on-field) and the constant columns.
      s << (i ^ order_mask);
      for (int j = 1; j < n_cols; j++) {
        s << ", " << j;
      }
    }
    s << ']';
  }
  s << ']';
  return s.str();
}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ ASSERT_TRUE(l_unordered || r_unordered);
+ int n_rows = 5;
+ auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields(),
l_unordered);
+ auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields(),
r_unordered);
+ ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema,
{l_str}));
+ ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema,
{r_str}));
+
+ return DoInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str,
+ /*then_run_plan=*/true);
+}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema,
+ AsofJoinNodeOptions("time", {"key"}, 1000),
+ "out-of-order on-key values");
+}
+
+struct BasicTestTypes {
+ std::shared_ptr<DataType> time, key, l_val, r0_val, r1_val;
+};
+
+struct BasicTest {
+ BasicTest(const std::vector<util::string_view>& l_data,
+ const std::vector<util::string_view>& r0_data,
+ const std::vector<util::string_view>& r1_data,
+ const std::vector<util::string_view>& exp_nokey_data,
+ const std::vector<util::string_view>& exp_emptykey_data,
+ const std::vector<util::string_view>& exp_data, int64_t tolerance)
+ : l_data(std::move(l_data)),
+ r0_data(std::move(r0_data)),
+ r1_data(std::move(r1_data)),
+ exp_nokey_data(std::move(exp_nokey_data)),
+ exp_emptykey_data(std::move(exp_emptykey_data)),
+ exp_data(std::move(exp_data)),
+ tolerance(tolerance) {}
+
+ static inline void check_init(const std::vector<std::shared_ptr<DataType>>&
types) {
+ ASSERT_NE(0, types.size());
+ }
+
+ template <typename TypeCond>
+ static inline std::vector<std::shared_ptr<DataType>> init_types(
+ const std::vector<std::shared_ptr<DataType>>& all_types, TypeCond
type_cond) {
+ std::vector<std::shared_ptr<DataType>> types;
+ for (auto type : all_types) {
+ if (type_cond(type)) {
+ types.push_back(type);
+ }
+ }
+ check_init(types);
+ return types;
+ }
+
+ void RunSingleByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
"key",
+ tolerance);
+ });
+ }
+ static void DoSingleByKey(BasicTest& basic_tests) {
basic_tests.RunSingleByKey(); }
+ void RunDoubleByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key"}, tolerance);
+ });
+ }
+ static void DoDoubleByKey(BasicTest& basic_tests) {
basic_tests.RunDoubleByKey(); }
+ void RunMutateByKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2"));
+ ASSERT_OK_AND_ASSIGN(exp_batches, MutateByKey(exp_batches, "key",
"key2"));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key2"}, tolerance);
+ });
+ }
+ static void DoMutateByKey(BasicTest& basic_tests) {
basic_tests.RunMutateByKey(); }
+ void RunMutateNoKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2",
true));
+ ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+ MutateByKey(exp_nokey_batches, "key", "key2",
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
"time", "key2",
+ tolerance);
+ });
+ }
+ static void DoMutateNoKey(BasicTest& basic_tests) {
basic_tests.RunMutateNoKey(); }
+ void RunMutateNullKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2",
true, true));
+ ASSERT_OK_AND_ASSIGN(r0_batches,
+ MutateByKey(r0_batches, "key", "key2", true, true));
+ ASSERT_OK_AND_ASSIGN(r1_batches,
+ MutateByKey(r1_batches, "key", "key2", true, true));
+ ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+ MutateByKey(exp_nokey_batches, "key", "key2", true,
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
+ AsofJoinNodeOptions("time", {"key2"}, tolerance));
+ });
+ }
+ static void DoMutateNullKey(BasicTest& basic_tests) {
basic_tests.RunMutateNullKey(); }
+ void RunMutateEmptyKey() {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_emptykey_batches, B exp_batches) {
+ ASSERT_OK_AND_ASSIGN(r0_batches,
+ MutateByKey(r0_batches, "key", "key", false, false,
true));
+ ASSERT_OK_AND_ASSIGN(r1_batches,
+ MutateByKey(r1_batches, "key", "key", false, false,
true));
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches,
+ AsofJoinNodeOptions("time", {}, tolerance));
+ });
+ }
+ static void DoMutateEmptyKey(BasicTest& basic_tests) {
+ basic_tests.RunMutateEmptyKey();
+ }
+ template <typename BatchesRunner>
+ void RunBatches(BatchesRunner batches_runner) {
+ std::vector<std::shared_ptr<DataType>> all_types = {
+ utf8(),
+ large_utf8(),
+ binary(),
+ large_binary(),
+ int8(),
+ int16(),
+ int32(),
+ int64(),
+ uint8(),
+ uint16(),
+ uint32(),
+ uint64(),
+ date32(),
+ date64(),
+ time32(TimeUnit::MILLI),
+ time32(TimeUnit::SECOND),
+ time64(TimeUnit::NANO),
+ time64(TimeUnit::MICRO),
+ timestamp(TimeUnit::NANO, "UTC"),
+ timestamp(TimeUnit::MICRO, "UTC"),
+ timestamp(TimeUnit::MILLI, "UTC"),
+ timestamp(TimeUnit::SECOND, "UTC"),
+ float32(),
+ float64()};
+ using T = const std::shared_ptr<DataType>;
+ // byte_width > 1 below allows fitting the tested data
+ auto time_types = init_types(
+ all_types, [](T& t) { return t->byte_width() > 1 &&
!is_floating(t->id()); });
+ auto key_types = init_types(all_types, [](T& t) { return
!is_floating(t->id()); });
+ auto l_types = init_types(all_types, [](T& t) { return true; });
+ auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() >
1; });
+ auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() >
1; });
+
+ // sample a limited number of type-combinations to keep the runnning time
reasonable
+ // the scoped-traces below help reproduce a test failure, should it happen
+ auto start_time = std::chrono::system_clock::now();
+ auto seed = start_time.time_since_epoch().count();
+ ARROW_SCOPED_TRACE("Types seed: ", seed);
+ std::default_random_engine engine(static_cast<unsigned int>(seed));
+ std::uniform_int_distribution<size_t> time_distribution(0,
time_types.size() - 1);
+ std::uniform_int_distribution<size_t> key_distribution(0, key_types.size()
- 1);
+ std::uniform_int_distribution<size_t> l_distribution(0, l_types.size() -
1);
+ std::uniform_int_distribution<size_t> r0_distribution(0, r0_types.size() -
1);
+ std::uniform_int_distribution<size_t> r1_distribution(0, r1_types.size() -
1);
+
+ for (int i = 0; i < 1000; i++) {
+ auto time_type = time_types[time_distribution(engine)];
+ ARROW_SCOPED_TRACE("Time type: ", *time_type);
+ auto key_type = key_types[key_distribution(engine)];
+ ARROW_SCOPED_TRACE("Key type: ", *key_type);
+ auto l_type = l_types[l_distribution(engine)];
+ ARROW_SCOPED_TRACE("Left type: ", *l_type);
+ auto r0_type = r0_types[r0_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type);
+ auto r1_type = r1_types[r1_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type);
+
+ RunTypes({time_type, key_type, l_type, r0_type, r1_type},
batches_runner);
+
+ auto end_time = std::chrono::system_clock::now();
+ std::chrono::duration<double> diff = end_time - start_time;
+ if (diff.count() > 2) {
+ std::cerr << "AsofJoin test reached time limit at iteration " << i <<
std::endl;
+ // this normally happens on slow CI systems, but is fine
+ break;
+ }
Review Comment:
I removed this print.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]