westonpace commented on code in PR #13914:
URL: https://github.com/apache/arrow/pull/13914#discussion_r966404968


##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -34,11 +34,53 @@ namespace arrow {
 namespace engine {
 
 template <typename RelMessage>
-Status CheckRelCommon(const RelMessage& rel) {
+bool HasEmit(const RelMessage& rel) {
+  if (rel.has_common()) {
+    return rel.common().has_emit();

Review Comment:
   For consistency this should probably be `switch (rel.common().emit_kind_case())`.
   
   Actually...is this function still used?
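   
   If it is still needed, the switch form could look roughly like this (a sketch only; `RelCommon::kEmit` is the assumed protoc-generated oneof case constant for `emit_kind`):
   
   ```cpp
   template <typename RelMessage>
   bool HasEmit(const RelMessage& rel) {
     if (!rel.has_common()) return false;
     switch (rel.common().emit_kind_case()) {
       // kEmit is the assumed generated case constant for RelCommon::emit_kind
       case substrait::RelCommon::kEmit:
         return true;
       default:
         return false;
     }
   }
   ```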



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -94,9 +135,11 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
         const substrait::ReadRel::NamedTable& named_table = read.named_table();
         std::vector<std::string> table_names(named_table.names().begin(),
                                              named_table.names().end());
-        ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
+        ARROW_ASSIGN_OR_RAISE(compute::Declaration no_emit_declaration,

Review Comment:
   Can you revert this change?  From the perspective of `FromProto` I think 
this is still the "source_decl".  The name "no_emit_declaration" only really 
makes sense inside `ProcessEmit`.  If you need to differentiate I think you 
could have `source_no_emit` or `source_before_emit` but I'm not sure that is 
really needed.  The old name was fine I think.



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -244,23 +290,48 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
      // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
      // them. Therefore, we need to prefix all the current columns for compatibility.
       std::vector<compute::Expression> expressions;
-      expressions.reserve(input.num_columns + project.expressions().size());
-      for (int i = 0; i < input.num_columns; i++) {
+      int num_columns = input.output_schema->num_fields();
+      expressions.reserve(num_columns + project.expressions().size());
+      for (int i = 0; i < num_columns; i++) {
         expressions.emplace_back(compute::field_ref(FieldRef(i)));
       }
+
+      int i = 0;
+      auto project_schema = input.output_schema;
       for (const auto& expr : project.expressions()) {
-        expressions.emplace_back();
-        ARROW_ASSIGN_OR_RAISE(expressions.back(),
+        std::shared_ptr<Field> project_field;
+        ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
                               FromProto(expr, ext_set, conversion_options));
+        auto bound_expr = des_expr.Bind(*input.output_schema);
+        if (auto* expr_call = bound_expr->call()) {
+          project_field = field(expr_call->function_name,
+                                expr_call->kernel->signature->out_type().type());
+        } else if (auto* field_ref = des_expr.field_ref()) {
+          ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
+                                field_ref->FindOne(*input.output_schema));
+          ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema));
+        } else if (auto* literal = des_expr.literal()) {
+          project_field =
+              field("field_" + std::to_string(num_columns + i), 
literal->type());
+        }
+        ARROW_ASSIGN_OR_RAISE(
+            project_schema,
+            project_schema->AddField(
+                num_columns + static_cast<int>(project.expressions().size()) - 1,
+                std::move(project_field)));
+        i++;
+        expressions.emplace_back(des_expr);
       }
 
-      auto num_columns = static_cast<int>(expressions.size());
-      return DeclarationInfo{
+      DeclarationInfo no_emit_declaration{
           compute::Declaration::Sequence({
               std::move(input.declaration),
               {"project", compute::ProjectNodeOptions{std::move(expressions)}},
           }),
-          num_columns};
+          project_schema};
+
+      return ProcessEmit(std::move(project), std::move(no_emit_declaration),

Review Comment:
   ```suggestion
         return ProcessEmit(std::move(project), std::move(project_declaration),
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -399,17 +495,38 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
              ExtensionIdRegistry::SubstraitAggregateToArrow converter,
              ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id()));
          ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+
+          // find aggregate field ids from schema
+          const auto field_ref = arrow_agg.target;
+          ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+          agg_src_field_ids[measure_id] = match[0];
+
           aggregates.push_back(std::move(arrow_agg));
         } else {
           return Status::Invalid("substrait::AggregateFunction not provided");
         }
       }
+      FieldVector output_fields;
+      output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size());
+      // extract aggregate fields to output schema
+      for (int id = 0; id < static_cast<int>(agg_src_field_ids.size()); id++) {
+        output_fields.emplace_back(input_schema->field(agg_src_field_ids[id]));
+      }
+      // extract key fields to output schema
+      for (int id = 0; id < static_cast<int>(key_field_ids.size()); id++) {
+        output_fields.emplace_back(input_schema->field(key_field_ids[id]));
+      }
 
-      return DeclarationInfo{
+      std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
+
+      DeclarationInfo no_emit_declaration{

Review Comment:
   ```suggestion
         DeclarationInfo aggregate_declaration{
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -94,9 +135,11 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
         const substrait::ReadRel::NamedTable& named_table = read.named_table();
         std::vector<std::string> table_names(named_table.names().begin(),
                                              named_table.names().end());
-        ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
+        ARROW_ASSIGN_OR_RAISE(compute::Declaration no_emit_declaration,
                               named_table_provider(table_names));
-        return DeclarationInfo{std::move(source_decl), num_columns};
+        return ProcessEmit(std::move(read),

Review Comment:
   Same, revert to `source_decl`.



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -338,15 +409,29 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       if (!left_keys || !right_keys) {
         return Status::Invalid("Left keys for join cannot be null");
       }
+
+      // Create output schema from left, right relations and join keys
+      std::shared_ptr<Schema> join_schema = left.output_schema;
+      std::shared_ptr<Schema> right_schema = right.output_schema;
+
+      for (const auto& field : right_schema->fields()) {
+        ARROW_ASSIGN_OR_RAISE(
+            join_schema, join_schema->AddField(
+                             static_cast<int>(join_schema->fields().size()) - 1, field));
+      }

Review Comment:
   Each time you call `AddField` it is going to create a new schema, create a 
new vector of fields, and copy all the fields over.  So this ends up having 
O(n^2) complexity.  Instead I think you could do something like...
   
   ```
   FieldVector combined_fields = left.output_schema->fields();
   const FieldVector& right_fields = right.output_schema->fields();
   combined_fields.insert(combined_fields.end(), right_fields.begin(), right_fields.end());
   std::shared_ptr<Schema> join_schema = schema(std::move(combined_fields));
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -222,19 +267,20 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       }
       ARROW_ASSIGN_OR_RAISE(auto condition,
                            FromProto(filter.condition(), ext_set, conversion_options));
-
-      return DeclarationInfo{
+      DeclarationInfo no_emit_declaration{
           compute::Declaration::Sequence({
               std::move(input.declaration),
               {"filter", compute::FilterNodeOptions{std::move(condition)}},
           }),
-          input.num_columns};
+          input.output_schema};
+
+      return ProcessEmit(std::move(filter), std::move(no_emit_declaration),

Review Comment:
   ```suggestion
         return ProcessEmit(std::move(filter), std::move(filter_declaration),
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -199,12 +242,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
                                                 std::move(filesystem), std::move(files),
                                                  std::move(format), {}));
 
-      ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema)));
+      ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(base_schema));
 
-      return DeclarationInfo{
-          compute::Declaration{
-              "scan", dataset::ScanNodeOptions{std::move(ds), 
std::move(scan_options)}},
-          num_columns};
+      DeclarationInfo no_emit_declaration = {
+          compute::Declaration{"scan", dataset::ScanNodeOptions{ds, 
scan_options}},
+          base_schema};
+
+      return ProcessEmit(std::move(read), std::move(no_emit_declaration),
+                         std::move(base_schema));

Review Comment:
   ```suggestion
         DeclarationInfo scan_declaration = {
             compute::Declaration{"scan", dataset::ScanNodeOptions{ds, 
scan_options}},
             base_schema};
   
         return ProcessEmit(std::move(read), std::move(scan_declaration),
                            std::move(base_schema));
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -338,15 +409,29 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       if (!left_keys || !right_keys) {
         return Status::Invalid("Left keys for join cannot be null");
       }
+
+      // Create output schema from left, right relations and join keys
+      std::shared_ptr<Schema> join_schema = left.output_schema;
+      std::shared_ptr<Schema> right_schema = right.output_schema;
+
+      for (const auto& field : right_schema->fields()) {
+        ARROW_ASSIGN_OR_RAISE(
+            join_schema, join_schema->AddField(
+                             static_cast<int>(join_schema->fields().size()) - 1, field));
+      }
+
       compute::HashJoinNodeOptions join_options{{std::move(*left_keys)},
                                                 {std::move(*right_keys)}};
       join_options.join_type = join_type;
       join_options.key_cmp = {join_key_cmp};
       compute::Declaration join_dec{"hashjoin", std::move(join_options)};
-      auto num_columns = left.num_columns + right.num_columns;
       join_dec.inputs.emplace_back(std::move(left.declaration));
       join_dec.inputs.emplace_back(std::move(right.declaration));
-      return DeclarationInfo{std::move(join_dec), num_columns};
+
+      DeclarationInfo no_emit_declaration{std::move(join_dec), join_schema};
+
+      return ProcessEmit(std::move(join), std::move(no_emit_declaration),
+                         std::move(join_schema));

Review Comment:
   ```suggestion
         DeclarationInfo join_declaration{std::move(join_dec), join_schema};
   
         return ProcessEmit(std::move(join), std::move(join_declaration),
                            std::move(join_schema));
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -244,23 +290,48 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
      // NOTE: Substrait ProjectRels *append* columns, while Acero's project node replaces
      // them. Therefore, we need to prefix all the current columns for compatibility.
       std::vector<compute::Expression> expressions;
-      expressions.reserve(input.num_columns + project.expressions().size());
-      for (int i = 0; i < input.num_columns; i++) {
+      int num_columns = input.output_schema->num_fields();
+      expressions.reserve(num_columns + project.expressions().size());
+      for (int i = 0; i < num_columns; i++) {
         expressions.emplace_back(compute::field_ref(FieldRef(i)));
       }
+
+      int i = 0;
+      auto project_schema = input.output_schema;
       for (const auto& expr : project.expressions()) {
-        expressions.emplace_back();
-        ARROW_ASSIGN_OR_RAISE(expressions.back(),
+        std::shared_ptr<Field> project_field;
+        ARROW_ASSIGN_OR_RAISE(compute::Expression des_expr,
                               FromProto(expr, ext_set, conversion_options));
+        auto bound_expr = des_expr.Bind(*input.output_schema);
+        if (auto* expr_call = bound_expr->call()) {
+          project_field = field(expr_call->function_name,
+                                expr_call->kernel->signature->out_type().type());
+        } else if (auto* field_ref = des_expr.field_ref()) {
+          ARROW_ASSIGN_OR_RAISE(FieldPath field_path,
+                                field_ref->FindOne(*input.output_schema));
+          ARROW_ASSIGN_OR_RAISE(project_field, field_path.Get(*input.output_schema));
+        } else if (auto* literal = des_expr.literal()) {
+          project_field =
+              field("field_" + std::to_string(num_columns + i), 
literal->type());
+        }
+        ARROW_ASSIGN_OR_RAISE(
+            project_schema,
+            project_schema->AddField(
+                num_columns + static_cast<int>(project.expressions().size()) - 1,
+                std::move(project_field)));
+        i++;
+        expressions.emplace_back(des_expr);
       }
 
-      auto num_columns = static_cast<int>(expressions.size());
-      return DeclarationInfo{
+      DeclarationInfo no_emit_declaration{

Review Comment:
   ```suggestion
         DeclarationInfo project_declaration{
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -45,6 +54,31 @@ using internal::checked_cast;
 
 namespace engine {
 
+Status WriteParquetData(const std::string& path,
+                        const std::shared_ptr<fs::FileSystem> file_system,
+                        const std::shared_ptr<Table> input) {
+  EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path));

Review Comment:
   Generally a test method should either return a `Status` and not have any calls to `EXPECT_...`, or it should return a value / void.
   
   A good rule of thumb is:
   
   - Test case: use `ASSERT_`, return void
   - Test helper method, used only within one file: use `ASSERT_` and return void, or use `EXPECT_` and return a value
   - Test helper method, shared by multiple test files: use `ARROW_`, return `Status` or `Result`
   
   



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -45,6 +54,31 @@ using internal::checked_cast;
 
 namespace engine {
 
+Status WriteParquetData(const std::string& path,
+                        const std::shared_ptr<fs::FileSystem> file_system,
+                        const std::shared_ptr<Table> input) {
+  EXPECT_OK_AND_ASSIGN(auto buffer_writer, file_system->OpenOutputStream(path));
+  PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(),

Review Comment:
   We are not in the `parquet::` namespace, so it is not acceptable to throw here.  Best to use `ARROW_RETURN_NOT_OK` instead, or, if you continue to use EXPECT-style macros, use `EXPECT_OK`.
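   
   For example, a minimal sketch of the Status-returning variant (the chunk size and the final `Close()` call are assumptions, not this PR's exact code):
   
   ```cpp
   Status WriteParquetData(const std::string& path,
                           const std::shared_ptr<fs::FileSystem> file_system,
                           const std::shared_ptr<Table> input) {
     ARROW_ASSIGN_OR_RAISE(auto buffer_writer, file_system->OpenOutputStream(path));
     // Propagate the Status instead of letting parquet throw a ParquetException
     ARROW_RETURN_NOT_OK(parquet::arrow::WriteTable(*input, arrow::default_memory_pool(),
                                                    buffer_writer, /*chunk_size=*/1024));
     return buffer_writer->Close();
   }
   ```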



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -399,17 +495,38 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
              ExtensionIdRegistry::SubstraitAggregateToArrow converter,
              ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id()));
          ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call));
+
+          // find aggregate field ids from schema
+          const auto field_ref = arrow_agg.target;
+          ARROW_ASSIGN_OR_RAISE(auto match, field_ref.FindOne(*input_schema));
+          agg_src_field_ids[measure_id] = match[0];
+
           aggregates.push_back(std::move(arrow_agg));
         } else {
           return Status::Invalid("substrait::AggregateFunction not provided");
         }
       }
+      FieldVector output_fields;
+      output_fields.reserve(key_field_ids.size() + agg_src_field_ids.size());
+      // extract aggregate fields to output schema
+      for (int id = 0; id < static_cast<int>(agg_src_field_ids.size()); id++) {
+        output_fields.emplace_back(input_schema->field(agg_src_field_ids[id]));
+      }
+      // extract key fields to output schema
+      for (int id = 0; id < static_cast<int>(key_field_ids.size()); id++) {
+        output_fields.emplace_back(input_schema->field(key_field_ids[id]));
+      }
 
-      return DeclarationInfo{
+      std::shared_ptr<Schema> aggregate_schema = schema(std::move(output_fields));
+
+      DeclarationInfo no_emit_declaration{
           compute::Declaration::Sequence(
               {std::move(input.declaration),
                {"aggregate", compute::AggregateNodeOptions{aggregates, 
keys}}}),
-          static_cast<int>(aggregates.size())};
+          aggregate_schema};
+
+      return ProcessEmit(std::move(aggregate), std::move(no_emit_declaration),

Review Comment:
   ```suggestion
      return ProcessEmit(std::move(aggregate), std::move(aggregate_declaration),
   ```



##########
cpp/src/arrow/engine/substrait/relation_internal.cc:
##########
@@ -222,19 +267,20 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       }
       ARROW_ASSIGN_OR_RAISE(auto condition,
                            FromProto(filter.condition(), ext_set, conversion_options));
-
-      return DeclarationInfo{
+      DeclarationInfo no_emit_declaration{

Review Comment:
   ```suggestion
         DeclarationInfo filter_declaration{
   ```



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -121,6 +155,57 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) {
   return compute::Expression{std::move(modified_call)};
 }
 
+// TODO: complete this interface

Review Comment:
   Does this still need to be completed?  Is there a follow-up JIRA?  What's the plan here?



##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -121,6 +155,57 @@ inline compute::Expression UseBoringRefs(const compute::Expression& expr) {
   return compute::Expression{std::move(modified_call)};
 }
 
+// TODO: complete this interface
+struct TempDataGenerator {
+  TempDataGenerator(const std::shared_ptr<Table> input_table,
+                    const std::string& file_prefix,
+                    std::unique_ptr<arrow::internal::TemporaryDir>& tempdir)
+      : input_table(input_table), file_prefix(file_prefix), tempdir(tempdir) {}
+
+  Status operator()() {
+    auto format = std::make_shared<arrow::dataset::ParquetFileFormat>();
+    auto filesystem = std::make_shared<fs::LocalFileSystem>();
+    const std::string file_name = file_prefix + ".parquet";
+    ARROW_ASSIGN_OR_RAISE(auto file_path, tempdir->path().Join(file_name));
+    data_file_path = file_path.ToString();
+    ARROW_EXPECT_OK(WriteParquetData(data_file_path, filesystem, input_table));
+    return Status::OK();
+  }
+
+  std::shared_ptr<Table> input_table;
+  std::string file_prefix;
+  std::unique_ptr<arrow::internal::TemporaryDir>& tempdir;
+  std::string data_file_path;
+};
+
+void CheckRoundTripResult(const std::shared_ptr<Schema> output_schema,
+                          const std::shared_ptr<Table> expected_table,
+                          compute::ExecContext& exec_context,
+                          std::shared_ptr<Buffer>& buf,
+                          const std::vector<int>& include_columns = {},
+                          const ConversionOptions& conversion_options = {}) {
+  std::shared_ptr<ExtensionIdRegistry> sp_ext_id_reg = MakeExtensionIdRegistry();
+  ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get();
+  ExtensionSet ext_set(ext_id_reg);
+  ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(
+                                            *buf, [] { return kNullConsumer; },
+                                            ext_id_reg, &ext_set, conversion_options));
+  auto other_declrs = sink_decls[0].inputs[0].get<compute::Declaration>();
+  arrow::AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
+  auto sink_node_options = compute::SinkNodeOptions{&sink_gen};
+  auto sink_declaration = compute::Declaration({"sink", sink_node_options, 
"e"});
+  auto declarations = compute::Declaration::Sequence({*other_declrs, sink_declaration});
+  ASSERT_OK_AND_ASSIGN(auto acero_plan, compute::ExecPlan::Make(&exec_context));

Review Comment:
   You should move these lines into `GetTableFromPlan` (maybe change it to `GetTableFromDeclaration`).
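   
   Roughly what I have in mind (a sketch only; the helper name, parameters, and sink-reader plumbing here are assumptions, not this PR's code):
   
   ```cpp
   Result<std::shared_ptr<Table>> GetTableFromDeclaration(
       const compute::Declaration& declaration, compute::ExecContext* exec_context,
       const std::shared_ptr<Schema>& output_schema) {
     // Append a sink node and build the plan inside the helper
     arrow::AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen;
     auto declarations = compute::Declaration::Sequence(
         {declaration, {"sink", compute::SinkNodeOptions{&sink_gen}, "e"}});
     ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context));
     ARROW_RETURN_NOT_OK(declarations.AddToPlan(plan.get()).status());
     ARROW_RETURN_NOT_OK(plan->StartProducing());
     // Drain the sink generator into a table
     std::shared_ptr<RecordBatchReader> reader = compute::MakeGeneratorReader(
         output_schema, std::move(sink_gen), exec_context->memory_pool());
     ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatchReader(reader.get()));
     ARROW_RETURN_NOT_OK(plan->finished().status());
     return table;
   }
   ```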



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
