rtpsw commented on code in PR #35874:
URL: https://github.com/apache/arrow/pull/35874#discussion_r1232191968
##########
cpp/src/arrow/acero/asof_join_node_test.cc:
##########
@@ -1381,36 +1587,85 @@ void TestBackpressure(BatchesMaker maker, int
num_batches, int batch_size,
ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(r0_schema, 1));
ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(r1_schema, 2));
- Declaration l_src = {
- "source", SourceNodeOptions(
- l_schema, MakeDelayedGen(l_batches, "0:fast", fast_delay,
noisy))};
- Declaration r0_src = {
- "source", SourceNodeOptions(
- r0_schema, MakeDelayedGen(r0_batches, "1:slow",
slow_delay, noisy))};
- Declaration r1_src = {
- "source", SourceNodeOptions(
- r1_schema, MakeDelayedGen(r1_batches, "2:fast",
fast_delay, noisy))};
+ BackpressureCountingNode::Register();
+ GatedNode::Register();
- Declaration asofjoin = {
- "asofjoin", {l_src, r0_src, r1_src}, GetRepeatedOptions(3, "time",
{"key"}, 1000)};
+ struct BackpressureSourceConfig {
+ std::string name_prefix;
+ bool is_gated;
+ std::shared_ptr<Schema> schema;
+ decltype(l_batches) batches;
- ASSERT_OK_AND_ASSIGN(std::unique_ptr<RecordBatchReader> batch_reader,
- DeclarationToReader(asofjoin, /*use_threads=*/false));
+ std::string name() const {
+ return name_prefix + ";" + (is_gated ? "gated" : "ungated");
+ }
+ };
+
+ Gate gate;
+ GatedNodeOptions gate_options(&gate);
+
+ // Two ungated and one gated
+ std::vector<BackpressureSourceConfig> source_configs = {
+ {"0", false, l_schema, l_batches},
+ {"1", true, r0_schema, r0_batches},
+ {"2", false, r1_schema, r1_batches},
+ };
- int64_t total_length = 0;
- for (;;) {
- ASSERT_OK_AND_ASSIGN(auto batch, batch_reader->Next());
- if (!batch) {
- break;
+ std::vector<BackpressureCounters> bp_counters(source_configs.size());
+ std::vector<Declaration> src_decls;
+ std::vector<std::shared_ptr<BackpressureCountingNodeOptions>> bp_options;
+ std::vector<Declaration::Input> bp_decls;
+ for (size_t i = 0; i < source_configs.size(); i++) {
+ const auto& config = source_configs[i];
+
+ src_decls.emplace_back("source",
+ SourceNodeOptions(config.schema,
GetGen(config.batches)));
+ bp_options.push_back(
+ std::make_shared<BackpressureCountingNodeOptions>(&bp_counters[i]));
+ std::shared_ptr<ExecNodeOptions> options = bp_options.back();
+ std::vector<Declaration::Input> bp_in = {src_decls.back()};
+ Declaration bp_decl = {BackpressureCountingNode::kFactoryName, bp_in,
+ std::move(options)};
+ if (config.is_gated) {
+ bp_decl = {GatedNode::kFactoryName, {bp_decl}, gate_options};
}
- total_length += batch->num_rows();
+ bp_decls.push_back(bp_decl);
+ }
+
+ Declaration asofjoin = {"asofjoin", bp_decls,
+ GetRepeatedOptions(source_configs.size(), "time",
{"key"}, 0)};
+
+ ASSERT_OK_AND_ASSIGN(std::shared_ptr<internal::ThreadPool> tpool,
+ internal::ThreadPool::Make(1));
+ ExecContext exec_ctx(default_memory_pool(), tpool.get());
+ Future<BatchesWithCommonSchema> batches_fut =
+ DeclarationToExecBatchesAsync(asofjoin, exec_ctx);
+
+ auto has_bp_been_applied = [&] {
+ int total_paused = 0;
+ for (const auto& counters : bp_counters) {
+ total_paused += counters.pause_count;
+ }
+ // One of the inputs is gated. The other two will eventually be paused by
the asof
+ // join node
+ return total_paused >= 2;
Review Comment:
Empirically, yes. I'll try this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org