uchenily commented on issue #45847: URL: https://github.com/apache/arrow/issues/45847#issuecomment-2732075625
```cpp #include <arrow/acero/exec_plan.h> #include <arrow/acero/options.h> #include <arrow/api.h> #include <arrow/result.h> #include <arrow/status.h> #include <arrow/util/async_generator.h> #include <iostream> #include <memory> #include <vector> namespace cp = arrow::compute; namespace ac = arrow::acero; static constexpr size_t MAX_BATCH_SIZE = 1 << 15; namespace arrow_ext { class RandomBatchGenerator { public: std::shared_ptr<arrow::Schema> schema; int64_t min_val_; int64_t max_val_; int64_t current_; RandomBatchGenerator(std::shared_ptr<arrow::Schema> schema, int64_t min_val, int64_t max_val) : schema{schema} , min_val_{min_val} , max_val_{max_val} , current_{min_val} {}; arrow::Result<std::shared_ptr<arrow::RecordBatch>> Generate(size_t num_rows) { num_rows_ = num_rows; for (std::shared_ptr<arrow::Field> field : schema->fields()) { ARROW_RETURN_NOT_OK(arrow::VisitTypeInline(*field->type(), this)); } auto temp = std::vector<std::shared_ptr<arrow::Array>>{}; std::swap(arrays_, temp); return arrow::RecordBatch::Make(schema, num_rows, temp); } arrow::Status Visit(const arrow::DataType &type) { return arrow::Status::NotImplemented("Generating data for", type.ToString()); } arrow::Status Visit(const arrow::Int64Type &) { auto builder = arrow::Int64Builder(); // auto d = std::uniform_int_distribution{min_val_, max_val_}; ARROW_RETURN_NOT_OK(builder.Reserve(num_rows_)); for (size_t i = 0; i < num_rows_; ++i) { // ARROW_RETURN_NOT_OK(builder.Append(d(gen_))); ARROW_RETURN_NOT_OK(builder.Append(current_++)); } if (current_ > max_val_) { current_ = min_val_; } ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish()); arrays_.push_back(array); return arrow::Status::OK(); } protected: // std::random_device rd_{}; // std::mt19937 gen_{rd_()}; std::vector<std::shared_ptr<arrow::Array>> arrays_; size_t num_rows_; }; inline auto MakeRandomSourceGenerator(std::shared_ptr<arrow::Schema> col_names, size_t num_batches, int64_t min_val, int64_t max_val) -> arrow::AsyncGenerator<std::optional<cp::ExecBatch>> { struct State { State(std::shared_ptr<arrow::Schema> schema, size_t num_batches, int64_t min_val, int64_t max_val) : schema_{schema} , num_batches_{num_batches} , generator_{schema_, min_val, max_val} {} std::atomic<size_t> index_{0}; std::shared_ptr<arrow::Schema> schema_; const size_t num_batches_; RandomBatchGenerator generator_; }; auto state = std::make_shared<State>(col_names, num_batches, min_val, max_val); return [shared = state]() -> arrow::Future<std::optional<cp::ExecBatch>> { if (shared->index_.load() == shared->num_batches_) { return arrow::AsyncGeneratorEnd<std::optional<cp::ExecBatch>>(); } shared->index_.fetch_add(1); auto maybe_batch = shared->generator_.Generate(MAX_BATCH_SIZE); if (!maybe_batch.ok()) { return arrow::AsyncGeneratorEnd<std::optional<cp::ExecBatch>>(); } std::cout << "Generating batch ... [" << shared->index_.load() << "/" << shared->num_batches_ << "]\n"; return arrow::Future<std::optional<cp::ExecBatch>>::MakeFinished( cp::ExecBatch{**maybe_batch}); }; } } // namespace arrow_ext auto HashJoin(size_t num_probe_batches, size_t num_build_batches) -> arrow::Status { auto probe_schema = std::make_shared<arrow::Schema>(std::vector{ std::make_shared<arrow::Field>("orderkey", arrow::int64()), std::make_shared<arrow::Field>("custkey_0", arrow::int64()), }); auto build_schema = std::make_shared<arrow::Schema>(std::vector{ std::make_shared<arrow::Field>("custkey", arrow::int64()), }); auto probe_generator0 = arrow_ext::MakeRandomSourceGenerator(probe_schema, num_probe_batches, 0, 3000000); auto build_generator1 = arrow_ext::MakeRandomSourceGenerator(build_schema, num_build_batches, 0, 3000000); auto probe_source_options0 = ac::SourceNodeOptions{probe_schema, probe_generator0}; auto build_source_options1 = ac::SourceNodeOptions{build_schema, build_generator1}; ac::Declaration left{"source", std::move(probe_source_options0)}; ac::Declaration right{"source", std::move(build_source_options1)}; ac::HashJoinNodeOptions join_opts{ac::JoinType::RIGHT_OUTER, {"custkey_0"}, {"custkey"}, cp::literal(true)}; ac::Declaration hashjoin{ "hashjoin", {std::move(left), std::move(right)}, join_opts }; bool has_hashnode = true; if (!has_hashnode) { // return ac::DeclarationToStatus(std::move(hashjoin)); auto table = ac::DeclarationToTable(std::move(hashjoin)); std::cout << "table num_rows: " << (*table)->num_rows() << '\n'; std::cout << "table (sliced): " << (*table)->Slice(0, 10)->ToString() << '\n'; return arrow::Status::OK(); } else { // hash sum std::vector<cp::Aggregate> aggrs; // std::shared_ptr<cp::FunctionOptions> options // = std::make_shared<cp::ScalarAggregateOptions>(); // aggrs.emplace_back("hash_sum", options, "y", "sum"); std::shared_ptr<cp::FunctionOptions> options = std::make_shared<cp::CountOptions>(cp::CountOptions::ONLY_VALID); aggrs.emplace_back("hash_count", options, "orderkey", "count"); auto key_fields = std::vector<arrow::FieldRef>({"custkey"}); auto aggregate_options = ac::AggregateNodeOptions{/*aggregates=*/aggrs, /*keys=*/key_fields}; auto aggregate = ac::Declaration{"aggregate", {std::move(hashjoin)}, std::move(aggregate_options)}; // return ac::DeclarationToStatus(std::move(aggregate)); // auto schema = ac::DeclarationToSchema(std::move(aggregate)); // LOG_DEBUG("schema: {}", (*schema)->ToString()); std::cout << ac::DeclarationToString(aggregate).ValueOrDie() << '\n'; auto table = ac::DeclarationToTable(std::move(aggregate)); std::cout << "table num_rows: " << (*table)->num_rows() << '\n'; std::cout << "table (sliced): " << (*table)->Slice(0, 10)->ToString() << '\n'; std::cout << "probe-side rows: " << num_probe_batches * MAX_BATCH_SIZE << '\n'; std::cout << "build-side rows: " << num_build_batches * MAX_BATCH_SIZE << '\n'; return arrow::Status::OK(); } } auto main(int argc, char **argv) -> int { if (argc != 3) { std::cout << "Usage: " << argv[0] << " {num_probe_batches} {num_build_batches}\n"; return 0; } size_t num_probe_batches = std::atoi(argv[1]); size_t num_build_batches = std::atoi(argv[2]); auto status = HashJoin(num_probe_batches, num_build_batches); if (!status.ok()) { std::cerr << "Error occurred: " << status.message() << '\n'; return EXIT_FAILURE; } return EXIT_SUCCESS; } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org