westonpace commented on code in PR #14118:
URL: https://github.com/apache/arrow/pull/14118#discussion_r1045135030
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -572,6 +600,17 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
return bind_schema;
}
+Result<std::unique_ptr<substrait::RelCommon>> GetRelCommonEmit(
Review Comment:
There is no need for this function to return a result (and then you can get
rid of the move on the return value).
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -562,6 +563,33 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
} else if (declr.factory_name == "filter") {
auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr));
+ } else if (declr.factory_name == "project") {
+ auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
+ ARROW_ASSIGN_OR_RAISE(auto input_schema, ExtractSchemaToBind(input_declr));
+ const int num_fields_before_proj = input_schema->num_fields();
+ const auto& opts = checked_cast<const
compute::ProjectNodeOptions&>(*(declr.options));
+ const auto& exprs = opts.expressions;
+ int i = 0;
+ bind_schema = input_schema;
+ for (const auto& expr : exprs) {
+ std::shared_ptr<Field> project_field;
+ auto bound_expr = expr.Bind(*input_schema);
+ if (auto* expr_call = bound_expr->call()) {
+ project_field = field(expr_call->function_name,
+ expr_call->kernel->signature->out_type().type());
+ } else if (auto* field_ref = bound_expr->field_ref()) {
+ ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
field_ref->FindOne(*input_schema));
+ ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input_schema));
+ } else if (auto* literal = bound_expr->literal()) {
+ project_field =
+ field("field_" + std::to_string(num_fields_before_proj + i),
literal->type());
+ }
Review Comment:
I think there are only three expression types at the moment, but should we
have a fallthrough `else` case here returning an error? Maybe someday we will
add a new expression type (e.g. subquery).
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -2082,43 +2103,121 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) {
auto sink_decls,
DeserializePlans(
*serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
- // filter declaration
- auto& roundtripped_filter =
std::get<compute::Declaration>(sink_decls[0].inputs[0]);
- const auto& filter_opts =
- checked_cast<const
compute::FilterNodeOptions&>(*(roundtripped_filter.options));
- auto roundtripped_expr = filter_opts.filter_expression;
-
- if (auto* call = roundtripped_expr.call()) {
- EXPECT_EQ(call->function_name, "equal");
- auto args = call->arguments;
- auto left_index = args[0].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
- auto right_index = args[1].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
- }
- // scan declaration
+ // assert filter declaration
+ const auto& roundtripped_filter =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ AssertFilterRelation(roundtripped_filter, std::move(filter_expr),
dummy_schema);
+ // assert scan declaration
const auto& roundtripped_scan =
std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
- const auto& dataset_opts =
- checked_cast<const
dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
- const auto& roundripped_ds = dataset_opts.dataset;
- EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
- ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
+ AssertScanRelation(roundtripped_scan, dataset, dummy_schema);
+ // assert results
+ AssertPlanExecutionResult(expected_table, roundtripped_filter, dummy_schema,
+ exec_context);
+}
- auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
- auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
- EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
- int64_t idx = 0;
- for (auto fragment : expected_frg_vec) {
- const auto* l_frag = checked_cast<const
dataset::FileFragment*>(fragment.get());
- const auto* r_frag =
- checked_cast<const
dataset::FileFragment*>(roundtrip_frg_vec[idx++].get());
- EXPECT_TRUE(l_frag->Equals(*r_frag));
+TEST(Substrait, FilterProjectPlanRoundTripping) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+ compute::ExecContext exec_context;
+ arrow::dataset::internal::Initialize();
+
+ auto dummy_schema = schema(
+ {field("key", int32()), field("shared", int32()), field("distinct",
int32())});
+
+ // creating a dummy dataset using a dummy table
+ auto table = TableFromJSON(dummy_schema, {R"([
+ [1, 1, 10],
+ [3, 4, 20]
+ ])",
+ R"([
+ [0, 2, 1],
+ [1, 3, 2],
+ [4, 1, 3],
+ [3, 1, 3],
+ [1, 2, 5]
+ ])",
+ R"([
+ [2, 2, 12],
+ [5, 3, 12],
+ [1, 3, 12]
+ ])"});
+
+ auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+ const std::string file_name = "serde_project_test.arrow";
Review Comment:
Sorry, I know it's been a while, and our utilities have advanced a bit. Can
we change this test up to run on in-memory data instead of bothering with
temporary files?
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -2082,43 +2103,121 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) {
auto sink_decls,
DeserializePlans(
*serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
- // filter declaration
- auto& roundtripped_filter =
std::get<compute::Declaration>(sink_decls[0].inputs[0]);
- const auto& filter_opts =
- checked_cast<const
compute::FilterNodeOptions&>(*(roundtripped_filter.options));
- auto roundtripped_expr = filter_opts.filter_expression;
-
- if (auto* call = roundtripped_expr.call()) {
- EXPECT_EQ(call->function_name, "equal");
- auto args = call->arguments;
- auto left_index = args[0].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
- auto right_index = args[1].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
- }
- // scan declaration
+ // assert filter declaration
+ const auto& roundtripped_filter =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ AssertFilterRelation(roundtripped_filter, std::move(filter_expr),
dummy_schema);
+ // assert scan declaration
const auto& roundtripped_scan =
std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
- const auto& dataset_opts =
- checked_cast<const
dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
- const auto& roundripped_ds = dataset_opts.dataset;
- EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
- ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
+ AssertScanRelation(roundtripped_scan, dataset, dummy_schema);
+ // assert results
+ AssertPlanExecutionResult(expected_table, roundtripped_filter, dummy_schema,
+ exec_context);
+}
- auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
- auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
- EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
- int64_t idx = 0;
- for (auto fragment : expected_frg_vec) {
- const auto* l_frag = checked_cast<const
dataset::FileFragment*>(fragment.get());
- const auto* r_frag =
- checked_cast<const
dataset::FileFragment*>(roundtrip_frg_vec[idx++].get());
- EXPECT_TRUE(l_frag->Equals(*r_frag));
+TEST(Substrait, FilterProjectPlanRoundTripping) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+ compute::ExecContext exec_context;
+ arrow::dataset::internal::Initialize();
+
+ auto dummy_schema = schema(
+ {field("key", int32()), field("shared", int32()), field("distinct",
int32())});
+
+ // creating a dummy dataset using a dummy table
+ auto table = TableFromJSON(dummy_schema, {R"([
+ [1, 1, 10],
+ [3, 4, 20]
+ ])",
+ R"([
+ [0, 2, 1],
+ [1, 3, 2],
+ [4, 1, 3],
+ [3, 1, 3],
+ [1, 2, 5]
+ ])",
+ R"([
+ [2, 2, 12],
+ [5, 3, 12],
+ [1, 3, 12]
+ ])"});
+
+ auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+ const std::string file_name = "serde_project_test.arrow";
+
+ ASSERT_OK_AND_ASSIGN(auto tempdir,
+
arrow::internal::TemporaryDir::Make("substrait-tempdir-project-"));
+ ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
+ std::string file_path_str = file_path.ToString();
+
+ WriteIpcData(file_path_str, filesystem, table);
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_path_str};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
}
- ASSERT_OK_AND_ASSIGN(auto rnd_trp_table,
- GetTableFromPlan(roundtripped_filter, exec_context,
dummy_schema));
- EXPECT_TRUE(expected_table->Equals(*rnd_trp_table));
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ filesystem, std::move(files),
format, {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto scan_options = std::make_shared<dataset::ScanOptions>();
+ scan_options->projection = compute::project({}, {});
+ compute::Expression project_fp_expr = compute::call(
+ "add", {compute::field_ref(FieldPath({1})),
compute::field_ref(FieldPath({2}))});
+
+ std::vector<compute::Expression> project_fp_exprs = {project_fp_expr};
+
+ auto declarations = compute::Declaration::Sequence(
+ {compute::Declaration(
+ {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
+ compute::Declaration(
+ {"project", compute::ProjectNodeOptions{project_fp_exprs}, "p"})});
+
+ // adding the project expression field to schema
+ auto project_schema = schema({field("add(shared,distinct)", int32())});
+
+ ASSERT_OK_AND_ASSIGN(auto expected_table,
+ GetTableFromPlan(declarations, exec_context,
project_schema));
+
+ std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg =
MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ ExtensionSet ext_set(ext_id_reg);
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations,
&ext_set));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto sink_decls,
+ DeserializePlans(
+ *serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
+ // assert project declaration
+ const auto& roundtripped_emit_project =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+
+ // assert emit projection declaration
+ // emit node used in the serialization outputs the 4th column since the
+ // 0, 1, 2 original columns from the input data and 4th column being the
+ // projected column.
+ auto expec_project_expr = {compute::field_ref(FieldPath({4}))};
+ AssertProjectRelation(roundtripped_emit_project, expec_project_expr,
project_schema);
+
+ // assert projection declaration
+ auto subs_project_expr = {compute::field_ref(FieldPath({0})),
Review Comment:
This seems correct for the first project node. See previous comment.
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -562,6 +563,33 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
} else if (declr.factory_name == "filter") {
auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr));
+ } else if (declr.factory_name == "project") {
+ auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
+ ARROW_ASSIGN_OR_RAISE(auto input_schema, ExtractSchemaToBind(input_declr));
+ const int num_fields_before_proj = input_schema->num_fields();
+ const auto& opts = checked_cast<const
compute::ProjectNodeOptions&>(*(declr.options));
+ const auto& exprs = opts.expressions;
+ int i = 0;
+ bind_schema = input_schema;
+ for (const auto& expr : exprs) {
Review Comment:
Minor nit: I have a slight preference for:
```
for (int i = 0; i < static_cast<int>(exprs.size()); i++) {
const auto& expr = exprs[i];
// ...
}
```
over
```
int i = 0;
for (const auto& expr : exprs) {
// ...
i++;
}
```
In other words, prefer a range-based for loop if you can, but don't force it if
you can't. That being said, this preference may just be me (I don't think the
Google style guide mandates it), so feel free to ignore this suggestion.
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -562,6 +563,33 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
} else if (declr.factory_name == "filter") {
auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr));
+ } else if (declr.factory_name == "project") {
+ auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
+ ARROW_ASSIGN_OR_RAISE(auto input_schema, ExtractSchemaToBind(input_declr));
+ const int num_fields_before_proj = input_schema->num_fields();
+ const auto& opts = checked_cast<const
compute::ProjectNodeOptions&>(*(declr.options));
+ const auto& exprs = opts.expressions;
+ int i = 0;
+ bind_schema = input_schema;
+ for (const auto& expr : exprs) {
+ std::shared_ptr<Field> project_field;
+ auto bound_expr = expr.Bind(*input_schema);
+ if (auto* expr_call = bound_expr->call()) {
+ project_field = field(expr_call->function_name,
+ expr_call->kernel->signature->out_type().type());
+ } else if (auto* field_ref = bound_expr->field_ref()) {
+ ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
field_ref->FindOne(*input_schema));
+ ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input_schema));
+ } else if (auto* literal = bound_expr->literal()) {
+ project_field =
+ field("field_" + std::to_string(num_fields_before_proj + i),
literal->type());
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ bind_schema, bind_schema->AddField(
+ num_fields_before_proj +
static_cast<int>(exprs.size()) - 1,
+ std::move(project_field)));
Review Comment:
Instead of:
```
std::shared_ptr<Schema> bind_schema = input_schema;
for (...) {
bind_schema = bind_schema->AddField(..);
}
```
we should do:
```
std::vector<std::shared_ptr<Field>> bind_fields(input_schema->fields());
for (...) {
bind_fields.push_back(...);
}
std::shared_ptr<Schema> bind_schema = schema(std::move(bind_fields));
```
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -562,6 +563,33 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
} else if (declr.factory_name == "filter") {
auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr));
+ } else if (declr.factory_name == "project") {
+ auto input_declr = std::get<compute::Declaration>(declr.inputs[0]);
+ ARROW_ASSIGN_OR_RAISE(auto input_schema, ExtractSchemaToBind(input_declr));
+ const int num_fields_before_proj = input_schema->num_fields();
+ const auto& opts = checked_cast<const
compute::ProjectNodeOptions&>(*(declr.options));
+ const auto& exprs = opts.expressions;
+ int i = 0;
+ bind_schema = input_schema;
+ for (const auto& expr : exprs) {
+ std::shared_ptr<Field> project_field;
+ auto bound_expr = expr.Bind(*input_schema);
+ if (auto* expr_call = bound_expr->call()) {
+ project_field = field(expr_call->function_name,
+ expr_call->kernel->signature->out_type().type());
Review Comment:
Is there an `expr_call->type`?
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -2082,43 +2103,121 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) {
auto sink_decls,
DeserializePlans(
*serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
- // filter declaration
- auto& roundtripped_filter =
std::get<compute::Declaration>(sink_decls[0].inputs[0]);
- const auto& filter_opts =
- checked_cast<const
compute::FilterNodeOptions&>(*(roundtripped_filter.options));
- auto roundtripped_expr = filter_opts.filter_expression;
-
- if (auto* call = roundtripped_expr.call()) {
- EXPECT_EQ(call->function_name, "equal");
- auto args = call->arguments;
- auto left_index = args[0].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
- auto right_index = args[1].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
- }
- // scan declaration
+ // assert filter declaration
+ const auto& roundtripped_filter =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ AssertFilterRelation(roundtripped_filter, std::move(filter_expr),
dummy_schema);
+ // assert scan declaration
const auto& roundtripped_scan =
std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
- const auto& dataset_opts =
- checked_cast<const
dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
- const auto& roundripped_ds = dataset_opts.dataset;
- EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
- ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
+ AssertScanRelation(roundtripped_scan, dataset, dummy_schema);
+ // assert results
+ AssertPlanExecutionResult(expected_table, roundtripped_filter, dummy_schema,
+ exec_context);
+}
- auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
- auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
- EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
- int64_t idx = 0;
- for (auto fragment : expected_frg_vec) {
- const auto* l_frag = checked_cast<const
dataset::FileFragment*>(fragment.get());
- const auto* r_frag =
- checked_cast<const
dataset::FileFragment*>(roundtrip_frg_vec[idx++].get());
- EXPECT_TRUE(l_frag->Equals(*r_frag));
+TEST(Substrait, FilterProjectPlanRoundTripping) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
Review Comment:
```suggestion
```
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -2082,43 +2103,121 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) {
auto sink_decls,
DeserializePlans(
*serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
- // filter declaration
- auto& roundtripped_filter =
std::get<compute::Declaration>(sink_decls[0].inputs[0]);
- const auto& filter_opts =
- checked_cast<const
compute::FilterNodeOptions&>(*(roundtripped_filter.options));
- auto roundtripped_expr = filter_opts.filter_expression;
-
- if (auto* call = roundtripped_expr.call()) {
- EXPECT_EQ(call->function_name, "equal");
- auto args = call->arguments;
- auto left_index = args[0].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
- auto right_index = args[1].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
- }
- // scan declaration
+ // assert filter declaration
+ const auto& roundtripped_filter =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ AssertFilterRelation(roundtripped_filter, std::move(filter_expr),
dummy_schema);
+ // assert scan declaration
const auto& roundtripped_scan =
std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
- const auto& dataset_opts =
- checked_cast<const
dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
- const auto& roundripped_ds = dataset_opts.dataset;
- EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
- ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
+ AssertScanRelation(roundtripped_scan, dataset, dummy_schema);
+ // assert results
+ AssertPlanExecutionResult(expected_table, roundtripped_filter, dummy_schema,
+ exec_context);
+}
- auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
- auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
- EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
- int64_t idx = 0;
- for (auto fragment : expected_frg_vec) {
- const auto* l_frag = checked_cast<const
dataset::FileFragment*>(fragment.get());
- const auto* r_frag =
- checked_cast<const
dataset::FileFragment*>(roundtrip_frg_vec[idx++].get());
- EXPECT_TRUE(l_frag->Equals(*r_frag));
+TEST(Substrait, FilterProjectPlanRoundTripping) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+ compute::ExecContext exec_context;
+ arrow::dataset::internal::Initialize();
+
+ auto dummy_schema = schema(
+ {field("key", int32()), field("shared", int32()), field("distinct",
int32())});
+
+ // creating a dummy dataset using a dummy table
+ auto table = TableFromJSON(dummy_schema, {R"([
+ [1, 1, 10],
+ [3, 4, 20]
+ ])",
+ R"([
+ [0, 2, 1],
+ [1, 3, 2],
+ [4, 1, 3],
+ [3, 1, 3],
+ [1, 2, 5]
+ ])",
+ R"([
+ [2, 2, 12],
+ [5, 3, 12],
+ [1, 3, 12]
+ ])"});
+
+ auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+ const std::string file_name = "serde_project_test.arrow";
+
+ ASSERT_OK_AND_ASSIGN(auto tempdir,
+
arrow::internal::TemporaryDir::Make("substrait-tempdir-project-"));
+ ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
+ std::string file_path_str = file_path.ToString();
+
+ WriteIpcData(file_path_str, filesystem, table);
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_path_str};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
}
- ASSERT_OK_AND_ASSIGN(auto rnd_trp_table,
- GetTableFromPlan(roundtripped_filter, exec_context,
dummy_schema));
- EXPECT_TRUE(expected_table->Equals(*rnd_trp_table));
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ filesystem, std::move(files),
format, {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto scan_options = std::make_shared<dataset::ScanOptions>();
+ scan_options->projection = compute::project({}, {});
+ compute::Expression project_fp_expr = compute::call(
+ "add", {compute::field_ref(FieldPath({1})),
compute::field_ref(FieldPath({2}))});
+
+ std::vector<compute::Expression> project_fp_exprs = {project_fp_expr};
+
+ auto declarations = compute::Declaration::Sequence(
+ {compute::Declaration(
+ {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
+ compute::Declaration(
+ {"project", compute::ProjectNodeOptions{project_fp_exprs}, "p"})});
+
+ // adding the project expression field to schema
+ auto project_schema = schema({field("add(shared,distinct)", int32())});
+
+ ASSERT_OK_AND_ASSIGN(auto expected_table,
+ GetTableFromPlan(declarations, exec_context,
project_schema));
+
+ std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg =
MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ ExtensionSet ext_set(ext_id_reg);
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations,
&ext_set));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto sink_decls,
+ DeserializePlans(
+ *serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
+ // assert project declaration
+ const auto& roundtripped_emit_project =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+
+ // assert emit projection declaration
+ // emit node used in the serialization outputs the 4th column since the
+ // 0, 1, 2 original columns from the input data and 4th column being the
+ // projected column.
+ auto expec_project_expr = {compute::field_ref(FieldPath({4}))};
Review Comment:
Right now I think our Substrait->Acero path does not special case the
"project" relation when handling emit. In other words, converting
Substrait->Acero could convert one project node into two back-to-back project
nodes. So I don't know that project will ever round trip. What I would expect
in this case (I think) is two project nodes.
First project node (four expressions):
field(0), field(1), field(2), field(1) + field(2)
Second project node (one expression):
field(3)
I'm not sure why it is `4` though instead of `3`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]