rtpsw commented on code in PR #13880:
URL: https://github.com/apache/arrow/pull/13880#discussion_r955462463
##########
cpp/src/arrow/compute/exec/asof_join_node_test.cc:
##########
@@ -74,237 +226,723 @@ void CheckRunOutput(const BatchesWithSchema& l_batches,
/*same_chunk_layout=*/true, /*flatten=*/true);
}
-void DoRunBasicTest(const std::vector<util::string_view>& l_data,
- const std::vector<util::string_view>& r0_data,
- const std::vector<util::string_view>& r1_data,
- const std::vector<util::string_view>& exp_data, int64_t
tolerance) {
- auto l_schema =
- schema({field("time", int64()), field("key", int32()), field("l_v0",
float64())});
- auto r0_schema =
- schema({field("time", int64()), field("key", int32()), field("r0_v0",
float64())});
- auto r1_schema =
- schema({field("time", int64()), field("key", int32()), field("r1_v0",
float32())});
-
- auto exp_schema = schema({
- field("time", int64()),
- field("key", int32()),
- field("l_v0", float64()),
- field("r0_v0", float64()),
- field("r1_v0", float32()),
- });
-
- // Test three table join
- BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches;
- l_batches = MakeBatchesFromString(l_schema, l_data);
- r0_batches = MakeBatchesFromString(r0_schema, r0_data);
- r1_batches = MakeBatchesFromString(r1_schema, r1_data);
- exp_batches = MakeBatchesFromString(exp_schema, exp_data);
- CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key",
- tolerance);
-}
-
-void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
- const std::shared_ptr<Schema>& r_schema) {
- BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"});
- BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"});
-
+#define CHECK_RUN_OUTPUT(by_key_type)
\
+ void CheckRunOutput(
\
+ const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches,
\
+ const BatchesWithSchema& r1_batches, const BatchesWithSchema&
exp_batches, \
+ const FieldRef time, by_key_type keys, const int64_t tolerance) {
\
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches,
\
+ AsofJoinNodeOptions(time, keys, tolerance));
\
+ }
+
+EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT)
+
+void DoInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str,
+ bool then_run_plan = false) {
ExecContext exec_ctx;
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
- AsofJoinNodeOptions join_options("time", "key", 0);
Declaration join{"asofjoin", join_options};
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{l_batches.schema, l_batches.gen(false,
false)}});
join.inputs.emplace_back(Declaration{
"source", SourceNodeOptions{r_batches.schema, r_batches.gen(false,
false)}});
- ASSERT_RAISES(Invalid, join.AddToPlan(plan.get()));
+ if (then_run_plan) {
+ AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+ ASSERT_OK(Declaration::Sequence({join, {"sink",
SinkNodeOptions{&sink_gen}}})
+ .AddToPlan(plan.get()));
+ EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid,
+
::testing::HasSubstr(expected_error_str),
+ StartAndCollect(plan.get(),
sink_gen));
+ } else {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr(expected_error_str),
+ join.AddToPlan(plan.get()));
+ }
+}
+
+void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches,
+ const BatchesWithSchema& r_batches,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {R"([])"});
+ BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {R"([])"});
+
+ return DoRunInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str);
+}
+
+void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema, int64_t
tolerance,
+ const std::string& expected_error_str) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", "key",
tolerance),
+ expected_error_str);
+}
+
+void DoRunInvalidTypeTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for ");
+}
+
+void DoRunInvalidToleranceTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, -1,
+ "AsOfJoin tolerance must be non-negative but is ");
+}
+
+void DoRunMissingKeysTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No
match");
+}
+
+void DoRunEmptyByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time", {}, 0),
+ "AsOfJoin by_key must not be empty");
+}
+
+void DoRunMissingOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("invalid_time",
"key", 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunMissingByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time",
"invalid_key", 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"},
"key", 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunNestedByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions("time",
FieldRef{0, 1}, 0),
+ "Bad join key on table : No match");
+}
+
+void DoRunAmbiguousOnKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table :
Multiple matches");
+}
+
+void DoRunAmbiguousByKeyTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table :
Multiple matches");
+}
+
+std::string GetJsonString(int n_rows, int n_cols, bool unordered = false) {
+ std::stringstream s;
+ s << '[';
+ for (int i = 0; i < n_rows; i++) {
+ if (i > 0) {
+ s << ", ";
+ }
+ s << '[';
+ for (int j = 0; j < n_cols; j++) {
+ if (j > 0) {
+ s << ", " << j;
+ } else if (j < 2) {
+ s << (i ^ unordered);
+ } else {
+ s << i;
+ }
+ }
+ s << ']';
+ }
+ s << ']';
+ return s.str();
+}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema,
+ const AsofJoinNodeOptions& join_options,
+ const std::string& expected_error_str) {
+ ASSERT_TRUE(l_unordered || r_unordered);
+ int n_rows = 5;
+ std::string l_str = GetJsonString(n_rows, l_schema->num_fields(),
l_unordered);
+ std::string r_str = GetJsonString(n_rows, r_schema->num_fields(),
r_unordered);
+ BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str});
+ BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str});
+
+ return DoInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str,
+ /*then_run_plan=*/true);
+}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+ const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema,
+ AsofJoinNodeOptions("time", "key", 1000),
+ "out-of-order on-key values");
+}
+
+void DoRunNullByKeyPlanTest(const std::shared_ptr<Schema>& l_schema,
+ const std::shared_ptr<Schema>& r_schema) {
+ AsofJoinNodeOptions join_options{"time", "key2", 1000};
+ std::string expected_error_str = "unexpected null by-key values";
+ int n_rows = 5;
+ std::string l_str = GetJsonString(n_rows, l_schema->num_fields());
+ std::string r_str = GetJsonString(n_rows, r_schema->num_fields());
+ BatchesWithSchema l_batches = MakeBatchesFromNumString(l_schema, {l_str});
+ BatchesWithSchema r_batches = MakeBatchesFromNumString(r_schema, {r_str});
+ l_batches = MutateByKey(l_batches, "key", "key2", true, true);
+ r_batches = MutateByKey(r_batches, "key", "key2", true, true);
+
+ return DoInvalidPlanTest(l_batches, r_batches, join_options,
expected_error_str,
+ /*then_run_plan=*/true);
}
+struct BasicTestTypes {
+ std::shared_ptr<DataType> time, key, l_val, r0_val, r1_val;
+};
+
+struct BasicTest {
+ BasicTest(const std::vector<util::string_view>& l_data,
+ const std::vector<util::string_view>& r0_data,
+ const std::vector<util::string_view>& r1_data,
+ const std::vector<util::string_view>& exp_nokey_data,
+ const std::vector<util::string_view>& exp_data, int64_t tolerance)
+ : l_data(std::move(l_data)),
+ r0_data(std::move(r0_data)),
+ r1_data(std::move(r1_data)),
+ exp_nokey_data(std::move(exp_nokey_data)),
+ exp_data(std::move(exp_data)),
+ tolerance(tolerance) {}
+
+ template <typename TypeCond>
+ static inline void init_types(const std::vector<std::shared_ptr<DataType>>&
all_types,
+ std::vector<std::shared_ptr<DataType>>& types,
+ TypeCond type_cond) {
+ if (types.size() == 0) {
+ for (auto type : all_types) {
+ if (type_cond(type)) {
+ types.push_back(type);
+ }
+ }
+ }
+ }
+
+ void RunSingleByKey(std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
"key",
+ tolerance);
+ });
+ }
+ void RunDoubleByKey(std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_batches) {
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key"}, tolerance);
+ });
+ }
+ void RunMutateByKey(std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_batches) {
+ l_batches = MutateByKey(l_batches, "key", "key2");
+ r0_batches = MutateByKey(r0_batches, "key", "key2");
+ r1_batches = MutateByKey(r1_batches, "key", "key2");
+ exp_batches = MutateByKey(exp_batches, "key", "key2");
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+ {"key", "key2"}, tolerance);
+ });
+ }
+ void RunMutateNoKey(std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_batches) {
+ l_batches = MutateByKey(l_batches, "key", "key2", true);
+ r0_batches = MutateByKey(r0_batches, "key", "key2", true);
+ r1_batches = MutateByKey(r1_batches, "key", "key2", true);
+ exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true);
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
"time", "key2",
+ tolerance);
+ });
+ }
+ void RunMutateNullKey(std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ using B = BatchesWithSchema;
+ RunBatches([this](B l_batches, B r0_batches, B r1_batches, B
exp_nokey_batches,
+ B exp_batches) {
+ l_batches = MutateByKey(l_batches, "key", "key2", true, true);
+ r0_batches = MutateByKey(r0_batches, "key", "key2", true, true);
+ r1_batches = MutateByKey(r1_batches, "key", "key2", true, true);
+ exp_nokey_batches = MutateByKey(exp_nokey_batches, "key", "key2", true,
true);
+ CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
+ AsofJoinNodeOptions("time", "key2", tolerance,
+ /*nullable_by_key=*/true));
+ });
+ }
+ template <typename BatchesRunner>
+ void RunBatches(BatchesRunner batches_runner,
+ std::vector<std::shared_ptr<DataType>> time_types = {},
+ std::vector<std::shared_ptr<DataType>> key_types = {},
+ std::vector<std::shared_ptr<DataType>> l_types = {},
+ std::vector<std::shared_ptr<DataType>> r0_types = {},
+ std::vector<std::shared_ptr<DataType>> r1_types = {}) {
+ std::vector<std::shared_ptr<DataType>> all_types = {
+ utf8(),
+ large_utf8(),
+ binary(),
+ large_binary(),
+ int8(),
+ int16(),
+ int32(),
+ int64(),
+ uint8(),
+ uint16(),
+ uint32(),
+ uint64(),
+ date32(),
+ date64(),
+ time32(TimeUnit::MILLI),
+ time32(TimeUnit::SECOND),
+ time64(TimeUnit::NANO),
+ time64(TimeUnit::MICRO),
+ timestamp(TimeUnit::NANO, "UTC"),
+ timestamp(TimeUnit::MICRO, "UTC"),
+ timestamp(TimeUnit::MILLI, "UTC"),
+ timestamp(TimeUnit::SECOND, "UTC"),
+ float32(),
+ float64()};
+ using T = const std::shared_ptr<DataType>;
+ // byte_width > 1 below allows fitting the tested data
+ init_types(all_types, time_types,
+ [](T& t) { return t->byte_width() > 1 && !is_floating(t->id());
});
+ ASSERT_NE(0, time_types.size());
+ init_types(all_types, key_types, [](T& t) { return !is_floating(t->id());
});
+ ASSERT_NE(0, key_types.size());
+ init_types(all_types, l_types, [](T& t) { return true; });
+ ASSERT_NE(0, l_types.size());
+ init_types(all_types, r0_types, [](T& t) { return t->byte_width() > 1; });
+ ASSERT_NE(0, r0_types.size());
+ init_types(all_types, r1_types, [](T& t) { return t->byte_width() > 1; });
+ ASSERT_NE(0, r1_types.size());
+
+ // sample a limited number of type-combinations to keep the runnning time
reasonable
+ // the scoped-traces below help reproduce a test failure, should it happen
+ auto start_time = std::chrono::system_clock::now();
+ auto seed = start_time.time_since_epoch().count();
+ ARROW_SCOPED_TRACE("Types seed: ", seed);
+ std::default_random_engine engine(static_cast<unsigned int>(seed));
+ std::uniform_int_distribution<size_t> time_distribution(0,
time_types.size() - 1);
+ std::uniform_int_distribution<size_t> key_distribution(0, key_types.size()
- 1);
+ std::uniform_int_distribution<size_t> l_distribution(0, l_types.size() -
1);
+ std::uniform_int_distribution<size_t> r0_distribution(0, r0_types.size() -
1);
+ std::uniform_int_distribution<size_t> r1_distribution(0, r1_types.size() -
1);
+
+ for (int i = 0; i < 1000; i++) {
+ auto time_type = time_types[time_distribution(engine)];
+ ARROW_SCOPED_TRACE("Time type: ", *time_type);
+ auto key_type = key_types[key_distribution(engine)];
+ ARROW_SCOPED_TRACE("Key type: ", *key_type);
+ auto l_type = l_types[l_distribution(engine)];
+ ARROW_SCOPED_TRACE("Left type: ", *l_type);
+ auto r0_type = r0_types[r0_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type);
+ auto r1_type = r1_types[r1_distribution(engine)];
+ ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type);
+
+ RunTypes({time_type, key_type, l_type, r0_type, r1_type},
batches_runner);
+
+ auto end_time = std::chrono::system_clock::now();
+ std::chrono::duration<double> diff = end_time - start_time;
+ if (diff.count() > 2) {
+ std::cerr << "AsofJoin test reached time limit at iteration " << i <<
std::endl;
+ // this normally happens on slow CI systems, but is fine
+ break;
+ }
+ }
+ }
+ template <typename BatchesRunner>
+ void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner)
{
+ const BasicTestTypes& b = basic_test_types;
+ auto l_schema =
+ schema({field("time", b.time), field("key", b.key), field("l_v0",
b.l_val)});
+ auto r0_schema =
+ schema({field("time", b.time), field("key", b.key), field("r0_v0",
b.r0_val)});
+ auto r1_schema =
+ schema({field("time", b.time), field("key", b.key), field("r1_v0",
b.r1_val)});
+
+ auto exp_schema = schema({
+ field("time", b.time),
+ field("key", b.key),
+ field("l_v0", b.l_val),
+ field("r0_v0", b.r0_val),
+ field("r1_v0", b.r1_val),
+ });
+
+ // Test three table join
+ BatchesWithSchema l_batches, r0_batches, r1_batches, exp_nokey_batches,
exp_batches;
+ l_batches = MakeBatchesFromNumString(l_schema, l_data);
+ r0_batches = MakeBatchesFromNumString(r0_schema, r0_data);
+ r1_batches = MakeBatchesFromNumString(r1_schema, r1_data);
+ exp_nokey_batches = MakeBatchesFromNumString(exp_schema, exp_nokey_data);
+ exp_batches = MakeBatchesFromNumString(exp_schema, exp_data);
+ batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches,
exp_batches);
+ }
+
+ std::vector<util::string_view> l_data;
+ std::vector<util::string_view> r0_data;
+ std::vector<util::string_view> r1_data;
+ std::vector<util::string_view> exp_nokey_data;
+ std::vector<util::string_view> exp_data;
+ int64_t tolerance;
+};
+
class AsofJoinTest : public testing::Test {};
-TEST(AsofJoinTest, TestBasic1) {
+#define ASOFJOIN_TEST_SET(name, num) \
Review Comment:
I fixed this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]