westonpace commented on code in PR #14118:
URL: https://github.com/apache/arrow/pull/14118#discussion_r977396888
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -2082,43 +2103,135 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) {
auto sink_decls,
DeserializePlans(
*serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
- // filter declaration
- auto& roundtripped_filter =
std::get<compute::Declaration>(sink_decls[0].inputs[0]);
- const auto& filter_opts =
- checked_cast<const
compute::FilterNodeOptions&>(*(roundtripped_filter.options));
- auto roundtripped_expr = filter_opts.filter_expression;
-
- if (auto* call = roundtripped_expr.call()) {
- EXPECT_EQ(call->function_name, "equal");
- auto args = call->arguments;
- auto left_index = args[0].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left);
- auto right_index = args[1].field_ref()->field_path()->indices()[0];
- EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right);
- }
- // scan declaration
+ // assert filter declaration
+ const auto& roundtripped_filter =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ AssertFilterRelation(roundtripped_filter, std::move(filter_expr),
dummy_schema);
+ // assert scan declaration
const auto& roundtripped_scan =
std::get<compute::Declaration>(roundtripped_filter.inputs[0]);
- const auto& dataset_opts =
- checked_cast<const
dataset::ScanNodeOptions&>(*(roundtripped_scan.options));
- const auto& roundripped_ds = dataset_opts.dataset;
- EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema));
- ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments());
- ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments());
+ AssertScanRelation(roundtripped_scan, dataset, dummy_schema);
+ // assert results
+ AssertPlanExecutionResult(expected_table, roundtripped_filter, dummy_schema,
+ exec_context);
+}
- auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs));
- auto expected_frg_vec = IteratorToVector(std::move(expected_frgs));
- EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size());
- int64_t idx = 0;
- for (auto fragment : expected_frg_vec) {
- const auto* l_frag = checked_cast<const
dataset::FileFragment*>(fragment.get());
- const auto* r_frag =
- checked_cast<const
dataset::FileFragment*>(roundtrip_frg_vec[idx++].get());
- EXPECT_TRUE(l_frag->Equals(*r_frag));
+TEST(Substrait, FilterProjectPlanRoundTripping) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+ compute::ExecContext exec_context;
+ arrow::dataset::internal::Initialize();
+
+ auto dummy_schema = schema(
+ {field("key", int32()), field("shared", int32()), field("distinct",
int32())});
+
+ // creating a dummy dataset using a dummy table
+ auto table = TableFromJSON(dummy_schema, {R"([
+ [1, 1, 10],
+ [3, 4, 20]
+ ])",
+ R"([
+ [0, 2, 1],
+ [1, 3, 2],
+ [4, 1, 3],
+ [3, 1, 3],
+ [1, 2, 5]
+ ])",
+ R"([
+ [2, 2, 12],
+ [5, 3, 12],
+ [1, 3, 12]
+ ])"});
+
+ auto format = std::make_shared<arrow::dataset::IpcFileFormat>();
+ auto filesystem = std::make_shared<fs::LocalFileSystem>();
+ const std::string file_name = "serde_project_test.arrow";
+
+ ASSERT_OK_AND_ASSIGN(auto tempdir,
+
arrow::internal::TemporaryDir::Make("substrait-tempdir-project-"));
+ ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
+ std::string file_path_str = file_path.ToString();
+
+ WriteIpcData(file_path_str, filesystem, table);
+
+ std::vector<fs::FileInfo> files;
+ const std::vector<std::string> f_paths = {file_path_str};
+
+ for (const auto& f_path : f_paths) {
+ ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path));
+ files.push_back(std::move(f_file));
}
- ASSERT_OK_AND_ASSIGN(auto rnd_trp_table,
- GetTableFromPlan(roundtripped_filter, exec_context,
dummy_schema));
- EXPECT_TRUE(expected_table->Equals(*rnd_trp_table));
+
+ ASSERT_OK_AND_ASSIGN(auto ds_factory,
dataset::FileSystemDatasetFactory::Make(
+ filesystem, std::move(files),
format, {}));
+ ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema));
+
+ auto scan_options = std::make_shared<dataset::ScanOptions>();
+ scan_options->projection = compute::project({}, {});
+ compute::Expression project_fp_expr = compute::call(
+ "add", {compute::field_ref(FieldPath({1})),
compute::field_ref(FieldPath({2}))});
+ // Acero only outputs the expressions mentioned in the projection but
Substrait
+ // outputs existing columns plus the expression.
+ // For validation purpose, a project expression with expected projection
+ // field plus existing fields is created for Acero.
+ // And for Substrait just the project expression with expected projection.
+ std::vector<compute::Expression> acero_project_fp_exprs = {
+ compute::field_ref(FieldPath({0})), compute::field_ref(FieldPath({1})),
+ compute::field_ref(FieldPath({2})), project_fp_expr};
+ std::vector<compute::Expression> substrait_project_fp_exprs =
{project_fp_expr};
+
+ auto acero_declarations = compute::Declaration::Sequence(
+ {compute::Declaration(
+ {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
+ compute::Declaration(
+ {"project", compute::ProjectNodeOptions{acero_project_fp_exprs},
"p"})});
+
+ // adding the project expression field to schema
+ ASSERT_OK_AND_ASSIGN(auto project_schema,
+ dummy_schema->AddField(dummy_schema->num_fields(),
+ field("add(shared,distinct)",
int32())));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto expected_table,
+ GetTableFromPlan(acero_declarations, exec_context, project_schema));
+
+ std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg =
MakeExtensionIdRegistry();
+ ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+ ExtensionSet ext_set(ext_id_reg);
+
+ // declaration which Substrait would expect since the idea is to do the
projection
+ // which would produce an equivalent result to the declaration created for
Acero with
+ // additional project fields.
+ auto substrait_declarations = compute::Declaration::Sequence(
+ {compute::Declaration(
+ {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}),
+ compute::Declaration(
+ {"project",
compute::ProjectNodeOptions{substrait_project_fp_exprs}, "p"})});
+
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan,
+ SerializePlan(substrait_declarations, &ext_set));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto sink_decls,
+ DeserializePlans(
+ *serialized_plan, [] { return kNullConsumer; }, ext_id_reg,
&ext_set));
+ // assert project declaration
+ const auto& roundtripped_project =
+ std::get<compute::Declaration>(sink_decls[0].inputs[0]);
+ // assert project declaration
+ // Note: the provided expressions for Substrait declaration only contains one
+ // expression, but substrait produces expressions for the existing number of
fields plus
+ // provided expression. Since the output expressions from the deserialized
relation
+ // contains fields which weren't used in the project expression.
+ AssertProjectRelation(roundtripped_project, acero_project_fp_exprs,
project_schema);
Review Comment:
I think this is wrong. However, even if you apply my suggestion to add an emit in
`ProjectRelationConverter`, you will probably still need to adjust this assertion,
because you will end up with two project nodes (we deserialize an emit as a dedicated project node).
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -557,7 +557,7 @@ Result<std::shared_ptr<Schema>> ExtractSchemaToBind(const
compute::Declaration&
if (declr.factory_name == "scan") {
const auto& opts = checked_cast<const
dataset::ScanNodeOptions&>(*(declr.options));
bind_schema = opts.dataset->schema();
- } else if (declr.factory_name == "filter") {
+ } else if (declr.factory_name == "filter" || declr.factory_name ==
"project") {
Review Comment:
Is this right? Doesn't a `project` also add new columns to the schema?
##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -643,6 +643,36 @@ Result<std::unique_ptr<substrait::FilterRel>>
FilterRelationConverter(
return std::move(filter_rel);
}
+Result<std::unique_ptr<substrait::ProjectRel>> ProjectRelationConverter(
+ const std::shared_ptr<Schema>& schema, const compute::Declaration&
declaration,
+ ExtensionSet* ext_set, const ConversionOptions& conversion_options) {
+ auto project_rel = make_unique<substrait::ProjectRel>();
+ const auto& project_node_options =
+ checked_cast<const compute::ProjectNodeOptions&>(*declaration.options);
+
+ if (declaration.inputs.size() == 0) {
+ return Status::Invalid("Project node doesn't have an input.");
+ }
+
+ // handling input
+ auto declr_input = declaration.inputs[0];
+ ARROW_ASSIGN_OR_RAISE(
+ auto input_rel,
+ ToProto(std::get<compute::Declaration>(declr_input), ext_set,
conversion_options));
+
+ for (const auto& expr : project_node_options.expressions) {
+ compute::Expression bound_expression;
+ if (!expr.IsBound()) {
+ ARROW_ASSIGN_OR_RAISE(bound_expression, expr.Bind(*schema));
+ }
+ ARROW_ASSIGN_OR_RAISE(auto subs_expr,
+ ToProto(bound_expression, ext_set,
conversion_options));
+ project_rel->mutable_expressions()->AddAllocated(subs_expr.release());
+ }
Review Comment:
An Acero project can remove or modify existing columns, whereas a Substrait
ProjectRel appends its expressions after the input columns. Probably the simplest
thing to do is to keep doing what you are doing here but also add an emit to
`project_rel` that removes all of the input columns.
For example, if you had three input columns (indices 0-2) and the expressions were
`field_ref(0)` and `field_ref(1) + field_ref(2)`, the expression outputs land at
indices 3 and 4, so you should have emit `[3, 4]` and expressions
`[field_ref(0), (field_ref(1) + field_ref(2))]`.

In theory you could also have emit `[0, 3]` and expressions `[field_ref(1) +
field_ref(2)]`, but I think it would be more work than it is worth at the moment
to detect "pure references" (i.e. expressions that pass an input column through
unmodified).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]