This is an automated email from the ASF dual-hosted git repository.

westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 24e5a580f9 GH-33899: [C++] Add NamedTapRel relation as a Substrait extension (#33909)
24e5a580f9 is described below

commit 24e5a580f907a2a995fb9183cb4cce6218c711e6
Author: rtpsw <[email protected]>
AuthorDate: Fri Feb 10 22:14:43 2023 +0200

    GH-33899: [C++] Add NamedTapRel relation as a Substrait extension (#33909)
    
    See #33899. This PR adds `NamedTapRel` and a simple test case with a no-op tap (i.e., one that just passes its input through).
    * Closes: #33899
    
    Lead-authored-by: Yaron Gvili <[email protected]>
    Co-authored-by: Weston Pace <[email protected]>
    Signed-off-by: Weston Pace <[email protected]>
---
 cpp/proto/substrait/extension_rels.proto           |  14 +++
 cpp/src/arrow/engine/substrait/options.cc          |  84 ++++++++++++++--
 cpp/src/arrow/engine/substrait/options.h           |  17 +++-
 .../arrow/engine/substrait/relation_internal.cc    |   6 +-
 cpp/src/arrow/engine/substrait/serde_test.cc       | 112 +++++++++++++++++++++
 cpp/src/arrow/type.cc                              |  15 +++
 cpp/src/arrow/type.h                               |   6 ++
 7 files changed, 242 insertions(+), 12 deletions(-)

diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto
index 0392405f80..78c11b7d7e 100644
--- a/cpp/proto/substrait/extension_rels.proto
+++ b/cpp/proto/substrait/extension_rels.proto
@@ -44,3 +44,17 @@ message AsOfJoinRel {
     repeated .substrait.Expression by = 2;
   }
 }
+
+// Named tap relation
+//
+// A tap is a relation having a single input relation that it passes through, while also
+// causing some side-effect, e.g., writing to external storage.
+message NamedTapRel {
+  // The kind of tap
+  string kind = 1;
+  // A name used to configure the tap, e.g., a URI defining the destination of writing
+  string name = 2;
+  // Column names for the tap's output. There must be exactly one name per input field;
+  // the tap's output schema is the input schema with its fields renamed to these names.
+  repeated string columns = 3;
+}
diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc
index be23ce1e64..b4b10a021d 100644
--- a/cpp/src/arrow/engine/substrait/options.cc
+++ b/cpp/src/arrow/engine/substrait/options.cc
@@ -30,23 +30,39 @@
 namespace arrow {
 namespace engine {
 
+namespace {
+
+std::vector<compute::Declaration::Input> MakeDeclarationInputs(
+    const std::vector<DeclarationInfo>& inputs) {
+  std::vector<compute::Declaration::Input> input_decls(inputs.size());
+  for (size_t i = 0; i < inputs.size(); i++) {
+    input_decls[i] = inputs[i].declaration;
+  }
+  return input_decls;
+}
+
+}  // namespace
+
 class BaseExtensionProvider : public ExtensionProvider {
  public:
-  Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+  Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+                               const std::vector<DeclarationInfo>& inputs,
                                const ExtensionDetails& ext_details,
                                const ExtensionSet& ext_set) override {
     auto details = dynamic_cast<const DefaultExtensionDetails&>(ext_details);
-    return MakeRel(inputs, details.rel, ext_set);
+    return MakeRel(conv_opts, inputs, details.rel, ext_set);
   }
 
-  virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+  virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+                                       const std::vector<DeclarationInfo>& inputs,
                                        const google::protobuf::Any& rel,
                                        const ExtensionSet& ext_set) = 0;
 };
 
 class DefaultExtensionProvider : public BaseExtensionProvider {
  public:
-  Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+  Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+                               const std::vector<DeclarationInfo>& inputs,
                                const google::protobuf::Any& rel,
                                const ExtensionSet& ext_set) override {
     if (rel.Is<substrait_ext::AsOfJoinRel>()) {
@@ -54,6 +70,11 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
       rel.UnpackTo(&as_of_join_rel);
       return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
     }
+    if (rel.Is<substrait_ext::NamedTapRel>()) {
+      substrait_ext::NamedTapRel named_tap_rel;
+      if (!rel.UnpackTo(&named_tap_rel)) return Status::Invalid("Malformed NamedTapRel");
+      return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set);
+    }
     return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ",
                                   rel.DebugString());
   }
@@ -113,15 +134,38 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
     compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance};
 
     // declaration
-    std::vector<compute::Declaration::Input> input_decls(inputs.size());
-    for (size_t i = 0; i < inputs.size(); i++) {
-      input_decls[i] = inputs[i].declaration;
-    }
+    auto input_decls = MakeDeclarationInputs(inputs);
     return RelationInfo{
         {compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)),
          std::move(schema)},
         std::move(field_output_indices)};
   }
+
+  Result<RelationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
+                                       const std::vector<DeclarationInfo>& inputs,
+                                       const substrait_ext::NamedTapRel& named_tap_rel,
+                                       const ExtensionSet& ext_set) {
+    if (inputs.size() != 1) {
+      return Status::Invalid(
+          "substrait_ext::NamedTapRel requires a single input but got: ", inputs.size());
+    }
+
+    const auto& schema = inputs[0].output_schema;
+    int num_fields = schema->num_fields();
+    if (named_tap_rel.columns_size() != num_fields) {
+      return Status::Invalid("Got ", named_tap_rel.columns_size(),
+                             " NamedTapRel columns but expected ", num_fields);
+    }
+    std::vector<std::string> columns(named_tap_rel.columns().begin(),
+                                     named_tap_rel.columns().end());
+    ARROW_ASSIGN_OR_RAISE(auto renamed_schema, schema->WithNames(columns));
+    // Pass a copy of renamed_schema: it is also returned in the RelationInfo below.
+    ARROW_ASSIGN_OR_RAISE(
+        auto decl,
+        conv_opts.named_tap_provider(named_tap_rel.kind(), MakeDeclarationInputs(inputs),
+                                     named_tap_rel.name(), renamed_schema));
+    return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
+  }
 };
 
 namespace {
@@ -143,5 +187,29 @@ void set_default_extension_provider(const std::shared_ptr<ExtensionProvider>& pr
   g_default_extension_provider = provider;
 }
 
+namespace {
+
+NamedTapProvider g_default_named_tap_provider =
+    [](const std::string& tap_kind, std::vector<compute::Declaration::Input> inputs,
+       const std::string& tap_name,
+       std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
+  return Status::NotImplemented(
+      "Plan contained a NamedTapRel but no provider configured");
+};
+
+std::mutex g_default_named_tap_provider_mutex;
+
+}  // namespace
+
+NamedTapProvider default_named_tap_provider() {
+  std::lock_guard<std::mutex> lock(g_default_named_tap_provider_mutex);
+  return g_default_named_tap_provider;
+}
+
+void set_default_named_tap_provider(NamedTapProvider provider) {
+  std::lock_guard<std::mutex> lock(g_default_named_tap_provider_mutex);
+  g_default_named_tap_provider = std::move(provider);
+}
+
 }  // namespace engine
 }  // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h
index 35a4f70aa9..3b4a6963ac 100644
--- a/cpp/src/arrow/engine/substrait/options.h
+++ b/cpp/src/arrow/engine/substrait/options.h
@@ -23,6 +23,8 @@
 #include <string>
 #include <vector>
 
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
 #include "arrow/compute/type_fwd.h"
 #include "arrow/engine/substrait/type_fwd.h"
 #include "arrow/engine/substrait/visibility.h"
@@ -67,6 +69,10 @@ using NamedTableProvider =
     std::function<Result<compute::Declaration>(const std::vector<std::string>&)>;
 static NamedTableProvider kDefaultNamedTableProvider;
 
+using NamedTapProvider = std::function<Result<compute::Declaration>(
+    const std::string& tap_kind, std::vector<compute::Declaration::Input> inputs,
+    const std::string& tap_name, std::shared_ptr<Schema> tap_schema)>;
+
 class ARROW_ENGINE_EXPORT ExtensionDetails {
  public:
   virtual ~ExtensionDetails() = default;
@@ -75,7 +81,8 @@ class ARROW_ENGINE_EXPORT ExtensionDetails {
 class ARROW_ENGINE_EXPORT ExtensionProvider {
  public:
   virtual ~ExtensionProvider() = default;
-  virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+  virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+                                       const std::vector<DeclarationInfo>& inputs,
                                        const ExtensionDetails& ext_details,
                                        const ExtensionSet& ext_set) = 0;
 };
@@ -88,6 +95,10 @@ ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionProvider> default_extension_provide
 ARROW_ENGINE_EXPORT void set_default_extension_provider(
     const std::shared_ptr<ExtensionProvider>& provider);
 
+ARROW_ENGINE_EXPORT NamedTapProvider default_named_tap_provider();
+
+ARROW_ENGINE_EXPORT void set_default_named_tap_provider(NamedTapProvider provider);
+
 /// Options that control the conversion between Substrait and Acero representations of a
 /// plan.
 struct ARROW_ENGINE_EXPORT ConversionOptions {
@@ -98,6 +109,10 @@ struct ARROW_ENGINE_EXPORT ConversionOptions {
   /// The default behavior will return an invalid status if the plan has any
   /// named table relations.
   NamedTableProvider named_table_provider = kDefaultNamedTableProvider;
+  /// \brief A custom strategy to be used for obtaining a tap declaration
+  ///
+  /// Defaults to default_named_tap_provider(); the built-in one returns NotImplemented
+  NamedTapProvider named_tap_provider = default_named_tap_provider();
   std::shared_ptr<ExtensionProvider> extension_provider = default_extension_provider();
 };
 
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 4fb7bb2a78..19a38cd40e 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -206,7 +206,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
     case substrait::Rel::RelTypeCase::kExtensionLeaf: {
       const auto& ext = rel.extension_leaf();
       DefaultExtensionDetails detail{ext.detail()};
-      return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+      return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
     }
 
     case substrait::Rel::RelTypeCase::kExtensionSingle: {
@@ -215,7 +215,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
                             FromProto(ext.input(), ext_set, conv_opts));
       inputs.push_back(std::move(input_info));
       DefaultExtensionDetails detail{ext.detail()};
-      return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+      return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
     }
 
     case substrait::Rel::RelTypeCase::kExtensionMulti: {
@@ -225,7 +225,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
         inputs.push_back(std::move(input_info));
       }
       DefaultExtensionDetails detail{ext.detail()};
-      return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+      return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
     }
 
     default: {
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index d99d6fa0c4..97a2aea398 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -36,6 +36,7 @@
 #include "arrow/compute/exec/exec_plan.h"
 #include "arrow/compute/exec/expression.h"
 #include "arrow/compute/exec/expression_internal.h"
+#include "arrow/compute/exec/map_node.h"
 #include "arrow/compute/exec/options.h"
 #include "arrow/compute/exec/test_util.h"
 #include "arrow/compute/exec/util.h"
@@ -88,6 +89,30 @@ using internal::checked_cast;
 using internal::hash_combine;
 namespace engine {
 
+Status AddPassFactory(
+    const std::string& factory_name,
+    compute::ExecFactoryRegistry* registry = compute::default_exec_factory_registry()) {
+  using compute::ExecBatch;
+  using compute::ExecNode;
+  using compute::ExecNodeOptions;
+  using compute::ExecPlan;
+  struct PassNode : public compute::MapNode {
+    static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                  const ExecNodeOptions& options) {
+      RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PassNode"));
+      return plan->EmplaceNode<PassNode>(plan, inputs, inputs[0]->output_schema());
+    }
+
+    PassNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
+             std::shared_ptr<Schema> output_schema)
+        : MapNode(plan, std::move(inputs), std::move(output_schema)) {}
+
+    const char* kind_name() const override { return "PassNode"; }
+    Result<ExecBatch> ProcessBatch(ExecBatch batch) override { return batch; }
+  };
+  return registry->AddFactory(factory_name, PassNode::Make);
+}
+
 const auto kNullConsumer = std::make_shared<compute::NullSinkNodeConsumer>();
 
 void WriteIpcData(const std::string& path,
@@ -5355,5 +5380,92 @@ TEST(Substrait, AsOfJoinDefaultEmit) {
   CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
 }
 
+TEST(Substrait, PlanWithNamedTapExtension) {
+  // Runs a plan whose NamedTapRel tap is a registered pass-through exec node
+  std::string substrait_json = R"({
+    "extensionUris": [],
+    "extensions": [],
+    "relations": [{
+      "root": {
+        "input": {
+          "extension_multi": {
+            "inputs": [
+              {
+                "read": {
+                  "common": {
+                    "direct": {
+                    }
+                  },
+                  "baseSchema": {
+                    "names": ["time", "key", "value"],
+                    "struct": {
+                      "types": [
+                        {
+                          "i32": {
+                            "typeVariationReference": 0,
+                            "nullability": "NULLABILITY_NULLABLE"
+                          }
+                        },
+                        {
+                          "i32": {
+                            "typeVariationReference": 0,
+                            "nullability": "NULLABILITY_NULLABLE"
+                          }
+                        },
+                        {
+                          "fp64": {
+                            "typeVariationReference": 0,
+                            "nullability": "NULLABILITY_NULLABLE"
+                          }
+                        }
+                      ],
+                      "typeVariationReference": 0,
+                      "nullability": "NULLABILITY_REQUIRED"
+                    }
+                  },
+                  "namedTable": {
+                    "names": ["T"]
+                  }
+                }
+              }
+            ],
+            "detail": {
+              "@type": "/arrow.substrait_ext.NamedTapRel",
+              "kind" : "pass_for_named_tap",
+              "name" : "does_not_matter",
+              "columns" : ["pass_time", "pass_key", "pass_value"]
+            }
+          }
+        },
+        "names": ["t", "k", "v"]
+      }
+    }],
+    "expectedTypeUrls": []
+  })";
+
+  ASSERT_OK(AddPassFactory("pass_for_named_tap"));
+
+  std::shared_ptr<Schema> input_schema =
+      schema({field("time", int32()), field("key", int32()), field("value", float64())});
+  NamedTableProvider table_provider = AlwaysProvideSameTable(
+      TableFromJSON(input_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}));
+  ConversionOptions conversion_options;
+  conversion_options.named_table_provider = std::move(table_provider);
+  conversion_options.named_tap_provider =
+      [](const std::string& tap_kind, std::vector<compute::Declaration::Input> inputs,
+         const std::string& tap_name,
+         std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
+    return compute::Declaration{tap_kind, std::move(inputs), compute::ExecNodeOptions{}};
+  };
+
+  ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));
+
+  std::shared_ptr<Schema> output_schema =
+      schema({field("t", int32()), field("k", int32()), field("v", float64())});
+  auto expected_table =
+      TableFromJSON(output_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"});
+  CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
+}
+
 }  // namespace engine
 }  // namespace arrow
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 825a68e68c..0e3732db6e 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -1714,6 +1714,21 @@ bool Schema::HasDistinctFieldNames() const {
   return names.size() == fields.size();
 }
 
+Result<std::shared_ptr<Schema>> Schema::WithNames(
+    const std::vector<std::string>& names) const {
+  if (names.size() != impl_->fields_.size()) {
+    return Status::Invalid("attempted to rename schema with ", impl_->fields_.size(),
+                           " fields but ", names.size(), " new names were given");
+  }
+  FieldVector new_fields;
+  new_fields.reserve(names.size());
+  for (size_t i = 0; i < names.size(); ++i) {
+    new_fields.push_back(impl_->fields_[i]->WithName(names[i]));
+  }
+  // Keep the schema-level metadata; only the field names change.
+  return std::make_shared<Schema>(std::move(new_fields), impl_->metadata_);
+}
+
 std::shared_ptr<Schema> Schema::WithMetadata(
     const std::shared_ptr<const KeyValueMetadata>& metadata) const {
   return std::make_shared<Schema>(impl_->fields_, metadata);
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index cf58218a7e..4ea4796231 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1968,6 +1968,12 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable,
   Result<std::shared_ptr<Schema>> SetField(int i,
                                            const std::shared_ptr<Field>& field) const;
 
+  /// \brief Replace field names with new names
+  ///
+  /// \param[in] names new names, one per field, in field order
+  /// \return new Schema, or an Invalid status if names.size() != num_fields()
+  Result<std::shared_ptr<Schema>> WithNames(const std::vector<std::string>& names) const;
+
   /// \brief Replace key-value metadata with new metadata
   ///
   /// \param[in] metadata new KeyValueMetadata

Reply via email to