This is an automated email from the ASF dual-hosted git repository.
westonpace pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 24e5a580f9 GH-33899: [C++] Add NamedTapRel relation as a Substrait
extension (#33909)
24e5a580f9 is described below
commit 24e5a580f907a2a995fb9183cb4cce6218c711e6
Author: rtpsw <[email protected]>
AuthorDate: Fri Feb 10 22:14:43 2023 +0200
GH-33899: [C++] Add NamedTapRel relation as a Substrait extension (#33909)
See #33899. This PR adds `NamedTapRel` and a simple test case with a no-op
tap (i.e., just passing-through).
* Closes: #33899
Lead-authored-by: Yaron Gvili <[email protected]>
Co-authored-by: Weston Pace <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
---
cpp/proto/substrait/extension_rels.proto | 14 +++
cpp/src/arrow/engine/substrait/options.cc | 84 ++++++++++++++--
cpp/src/arrow/engine/substrait/options.h | 17 +++-
.../arrow/engine/substrait/relation_internal.cc | 6 +-
cpp/src/arrow/engine/substrait/serde_test.cc | 112 +++++++++++++++++++++
cpp/src/arrow/type.cc | 15 +++
cpp/src/arrow/type.h | 6 ++
7 files changed, 242 insertions(+), 12 deletions(-)
diff --git a/cpp/proto/substrait/extension_rels.proto
b/cpp/proto/substrait/extension_rels.proto
index 0392405f80..78c11b7d7e 100644
--- a/cpp/proto/substrait/extension_rels.proto
+++ b/cpp/proto/substrait/extension_rels.proto
@@ -44,3 +44,17 @@ message AsOfJoinRel {
repeated .substrait.Expression by = 2;
}
}
+
+// Named tap relation
+//
+// A tap is a relation having a single input relation that it passes through,
while also
+// causing some side-effect, e.g., writing to external storage.
+message NamedTapRel {
+ // The kind of tap
+ // (handed to the consumer's named tap provider together with `name`)
+ string kind = 1;
+ // A name used to configure the tap, e.g., a URI defining the destination of
writing
+ string name = 2;
+ // Column names for the tap's output. If specified there must be one name
per field.
+ // If empty, field names will be automatically generated.
+ // NOTE(review): the C++ consumer added in this same change (MakeNamedTapRel)
+ // rejects any plan whose number of columns differs from the number of input
+ // fields, so the "if empty" case does not appear to be implemented - confirm.
+ repeated string columns = 3;
+}
diff --git a/cpp/src/arrow/engine/substrait/options.cc
b/cpp/src/arrow/engine/substrait/options.cc
index be23ce1e64..b4b10a021d 100644
--- a/cpp/src/arrow/engine/substrait/options.cc
+++ b/cpp/src/arrow/engine/substrait/options.cc
@@ -30,23 +30,39 @@
namespace arrow {
namespace engine {
+namespace {
+
+// Gathers the declaration of each input relation, in order, as the list of
+// inputs for a new declaration.
+std::vector<compute::Declaration::Input> MakeDeclarationInputs(
+    const std::vector<DeclarationInfo>& inputs) {
+  std::vector<compute::Declaration::Input> declarations;
+  declarations.reserve(inputs.size());
+  for (const DeclarationInfo& info : inputs) {
+    declarations.emplace_back(info.declaration);
+  }
+  return declarations;
+}
+
+} // namespace
+
class BaseExtensionProvider : public ExtensionProvider {
public:
- Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+ Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+ const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) override {
auto details = dynamic_cast<const DefaultExtensionDetails&>(ext_details);
- return MakeRel(inputs, details.rel, ext_set);
+ return MakeRel(conv_opts, inputs, details.rel, ext_set);
}
- virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>&
inputs,
+ virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+ const std::vector<DeclarationInfo>&
inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) = 0;
};
class DefaultExtensionProvider : public BaseExtensionProvider {
public:
- Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+ Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+ const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) override {
if (rel.Is<substrait_ext::AsOfJoinRel>()) {
@@ -54,6 +70,11 @@ class DefaultExtensionProvider : public
BaseExtensionProvider {
rel.UnpackTo(&as_of_join_rel);
return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
}
+ if (rel.Is<substrait_ext::NamedTapRel>()) {
+ substrait_ext::NamedTapRel named_tap_rel;
+ rel.UnpackTo(&named_tap_rel);
+ return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set);
+ }
return Status::NotImplemented("Unrecognized extension in Susbstrait plan:
",
rel.DebugString());
}
@@ -113,15 +134,38 @@ class DefaultExtensionProvider : public
BaseExtensionProvider {
compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys),
tolerance};
// declaration
- std::vector<compute::Declaration::Input> input_decls(inputs.size());
- for (size_t i = 0; i < inputs.size(); i++) {
- input_decls[i] = inputs[i].declaration;
- }
+ auto input_decls = MakeDeclarationInputs(inputs);
return RelationInfo{
{compute::Declaration("asofjoin", input_decls,
std::move(asofjoin_node_opts)),
std::move(schema)},
std::move(field_output_indices)};
}
+
+  /// \brief Build the relation for a substrait_ext::NamedTapRel
+  ///
+  /// The tap takes exactly one input. Its output schema is the input schema
+  /// with the fields renamed to `named_tap_rel.columns()`, and the declaration
+  /// itself is obtained from `conv_opts.named_tap_provider`.
+  ///
+  /// \param[in] conv_opts conversion options supplying the named tap provider
+  /// \param[in] inputs relations feeding the tap; exactly one is required
+  /// \param[in] named_tap_rel the unpacked extension message
+  /// \param[in] ext_set extension set of the plan (not read by this function)
+  /// \return the tap relation; Invalid if the number of inputs is not one or
+  ///     the number of column names differs from the number of input fields;
+  ///     otherwise any error returned by the renaming or by the provider
+  Result<RelationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
+                                       const std::vector<DeclarationInfo>& inputs,
+                                       const substrait_ext::NamedTapRel& named_tap_rel,
+                                       const ExtensionSet& ext_set) {
+    if (inputs.size() != 1) {
+      return Status::Invalid(
+          "substrait_ext::NamedTapRel requires a single input but got: ", inputs.size());
+    }
+
+    const auto& schema = inputs[0].output_schema;
+    const int num_fields = schema->num_fields();
+    // Every input field must be given a name, so an empty `columns` list is
+    // only accepted when the input itself has no fields.
+    if (named_tap_rel.columns_size() != num_fields) {
+      return Status::Invalid("Got ", named_tap_rel.columns_size(),
+                             " NamedTapRel columns but expected ", num_fields);
+    }
+    std::vector<std::string> columns(named_tap_rel.columns().begin(),
+                                     named_tap_rel.columns().end());
+    ARROW_ASSIGN_OR_RAISE(auto renamed_schema, schema->WithNames(columns));
+    auto input_decls = MakeDeclarationInputs(inputs);
+    // The provider takes the schema by value. Pass it a copy: `renamed_schema`
+    // is still needed for the returned RelationInfo, and moving it here would
+    // leave a null output schema behind.
+    ARROW_ASSIGN_OR_RAISE(
+        auto decl,
+        conv_opts.named_tap_provider(named_tap_rel.kind(), std::move(input_decls),
+                                     named_tap_rel.name(), renamed_schema));
+    return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
+  }
};
namespace {
@@ -143,5 +187,29 @@ void set_default_extension_provider(const
std::shared_ptr<ExtensionProvider>& pr
g_default_extension_provider = provider;
}
+namespace {
+
+// Initial default tap provider: it ignores all of its arguments and always
+// fails with NotImplemented.
+NamedTapProvider g_default_named_tap_provider =
+ [](const std::string& tap_kind, std::vector<compute::Declaration::Input>
inputs,
+ const std::string& tap_name,
+ std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
+ return Status::NotImplemented(
+ "Plan contained a NamedTapRel but no provider configured");
+};
+
+// Held while g_default_named_tap_provider is read or replaced by the
+// default_named_tap_provider / set_default_named_tap_provider functions.
+std::mutex g_default_named_tap_provider_mutex;
+
+} // namespace
+
+// Returns a copy of the current default named tap provider; the copy is made
+// while holding the provider mutex.
+NamedTapProvider default_named_tap_provider() {
+  std::lock_guard<std::mutex> guard(g_default_named_tap_provider_mutex);
+  return g_default_named_tap_provider;
+}
+
+// Replaces the default named tap provider while holding the provider mutex.
+// `provider` is already a private copy (taken by value), so it is moved into
+// place rather than copied a second time.
+void set_default_named_tap_provider(NamedTapProvider provider) {
+  std::unique_lock<std::mutex> lock(g_default_named_tap_provider_mutex);
+  g_default_named_tap_provider = std::move(provider);
+}
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/options.h
b/cpp/src/arrow/engine/substrait/options.h
index 35a4f70aa9..3b4a6963ac 100644
--- a/cpp/src/arrow/engine/substrait/options.h
+++ b/cpp/src/arrow/engine/substrait/options.h
@@ -23,6 +23,8 @@
#include <string>
#include <vector>
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/engine/substrait/type_fwd.h"
#include "arrow/engine/substrait/visibility.h"
@@ -67,6 +69,10 @@ using NamedTableProvider =
std::function<Result<compute::Declaration>(const
std::vector<std::string>&)>;
static NamedTableProvider kDefaultNamedTableProvider;
+/// \brief Callback producing the declaration for a named tap relation
+///
+/// As invoked by the default extension provider, the arguments are, in order:
+/// the tap kind, the declarations feeding the tap, the tap name, and the
+/// schema of the tap's output (the input schema with the tap's column names).
+/// The callback returns the declaration to use for the tap, or an error.
+using NamedTapProvider = std::function<Result<compute::Declaration>(
+ const std::string&, std::vector<compute::Declaration::Input>, const
std::string&,
+ std::shared_ptr<Schema>)>;
+
class ARROW_ENGINE_EXPORT ExtensionDetails {
public:
virtual ~ExtensionDetails() = default;
@@ -75,7 +81,8 @@ class ARROW_ENGINE_EXPORT ExtensionDetails {
class ARROW_ENGINE_EXPORT ExtensionProvider {
public:
virtual ~ExtensionProvider() = default;
- virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>&
inputs,
+ virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
+ const std::vector<DeclarationInfo>&
inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) = 0;
};
@@ -88,6 +95,10 @@ ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionProvider>
default_extension_provide
ARROW_ENGINE_EXPORT void set_default_extension_provider(
const std::shared_ptr<ExtensionProvider>& provider);
+ARROW_ENGINE_EXPORT NamedTapProvider default_named_tap_provider();
+
+ARROW_ENGINE_EXPORT void set_default_named_tap_provider(NamedTapProvider
provider);
+
/// Options that control the conversion between Substrait and Acero
representations of a
/// plan.
struct ARROW_ENGINE_EXPORT ConversionOptions {
@@ -98,6 +109,10 @@ struct ARROW_ENGINE_EXPORT ConversionOptions {
/// The default behavior will return an invalid status if the plan has any
/// named table relations.
NamedTableProvider named_table_provider = kDefaultNamedTableProvider;
+ /// \brief A custom strategy to be used for obtaining a tap declaration
+ ///
+ /// The default provider returns an error
+ NamedTapProvider named_tap_provider = default_named_tap_provider();
std::shared_ptr<ExtensionProvider> extension_provider =
default_extension_provider();
};
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc
b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 4fb7bb2a78..19a38cd40e 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -206,7 +206,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const
substrait::Rel& rel,
case substrait::Rel::RelTypeCase::kExtensionLeaf: {
const auto& ext = rel.extension_leaf();
DefaultExtensionDetails detail{ext.detail()};
- return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+ return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail,
ext_set);
}
case substrait::Rel::RelTypeCase::kExtensionSingle: {
@@ -215,7 +215,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const
substrait::Rel& rel,
FromProto(ext.input(), ext_set, conv_opts));
inputs.push_back(std::move(input_info));
DefaultExtensionDetails detail{ext.detail()};
- return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+ return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail,
ext_set);
}
case substrait::Rel::RelTypeCase::kExtensionMulti: {
@@ -225,7 +225,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const
substrait::Rel& rel,
inputs.push_back(std::move(input_info));
}
DefaultExtensionDetails detail{ext.detail()};
- return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
+ return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail,
ext_set);
}
default: {
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc
b/cpp/src/arrow/engine/substrait/serde_test.cc
index d99d6fa0c4..97a2aea398 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -36,6 +36,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/expression_internal.h"
+#include "arrow/compute/exec/map_node.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
@@ -88,6 +89,30 @@ using internal::checked_cast;
using internal::hash_combine;
namespace engine {
+/// Registers, under `factory_name`, a factory for a map node that takes one
+/// input, reuses that input's schema as its output schema, and returns every
+/// batch unchanged.
+Status AddPassFactory(
+    const std::string& factory_name,
+    compute::ExecFactoryRegistry* registry = compute::default_exec_factory_registry()) {
+  struct PassNode : public compute::MapNode {
+    PassNode(compute::ExecPlan* plan, std::vector<compute::ExecNode*> inputs,
+             std::shared_ptr<Schema> output_schema)
+        : MapNode(plan, inputs, output_schema) {}
+
+    // Factory entry point: checks that there is exactly one input, then adds
+    // the node to the plan with the input's output schema.
+    static Result<compute::ExecNode*> Make(compute::ExecPlan* plan,
+                                           std::vector<compute::ExecNode*> inputs,
+                                           const compute::ExecNodeOptions& options) {
+      RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PassNode"));
+      return plan->EmplaceNode<PassNode>(plan, inputs, inputs[0]->output_schema());
+    }
+
+    const char* kind_name() const override { return "PassNode"; }
+
+    // The batch is handed back as-is.
+    Result<compute::ExecBatch> ProcessBatch(compute::ExecBatch batch) override {
+      return batch;
+    }
+  };
+  return registry->AddFactory(factory_name, PassNode::Make);
+}
+
const auto kNullConsumer = std::make_shared<compute::NullSinkNodeConsumer>();
void WriteIpcData(const std::string& path,
@@ -5355,5 +5380,92 @@ TEST(Substrait, AsOfJoinDefaultEmit) {
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}
+TEST(Substrait, PlanWithNamedTapExtension) {
+ // This demos an extension relation: a NamedTapRel wrapped around a read of
+ // named table "T", with the tap kind set to "pass_for_named_tap".
+ std::string substrait_json = R"({
+ "extensionUris": [],
+ "extensions": [],
+ "relations": [{
+ "root": {
+ "input": {
+ "extension_multi": {
+ "inputs": [
+ {
+ "read": {
+ "common": {
+ "direct": {
+ }
+ },
+ "baseSchema": {
+ "names": ["time", "key", "value"],
+ "struct": {
+ "types": [
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "fp64": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": ["T"]
+ }
+ }
+ }
+ ],
+ "detail": {
+ "@type": "/arrow.substrait_ext.NamedTapRel",
+ "kind" : "pass_for_named_tap",
+ "name" : "does_not_matter",
+ "columns" : ["pass_time", "pass_key", "pass_value"]
+ }
+ }
+ },
+ "names": ["t", "k", "v"]
+ }
+ }],
+ "expectedTypeUrls": []
+ })";
+
+ // Register the pass-through factory under the same name the plan uses as
+ // the tap kind.
+ ASSERT_OK(AddPassFactory("pass_for_named_tap"));
+
+ std::shared_ptr<Schema> input_schema =
+ schema({field("time", int32()), field("key", int32()), field("value",
float64())});
+ NamedTableProvider table_provider = AlwaysProvideSameTable(
+ TableFromJSON(input_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2,
3.1]]"}));
+ ConversionOptions conversion_options;
+ conversion_options.named_table_provider = std::move(table_provider);
+ // The tap provider declares a node whose factory name is the tap kind and
+ // whose inputs are the tap's inputs; the tap name and schema are not used.
+ conversion_options.named_tap_provider =
+ [](const std::string& tap_kind, std::vector<compute::Declaration::Input>
inputs,
+ const std::string& tap_name,
+ std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
+ return compute::Declaration{tap_kind, std::move(inputs),
compute::ExecNodeOptions{}};
+ };
+
+ ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan",
substrait_json));
+
+ // Expected result: the same rows as the input table, with the columns named
+ // "t", "k", "v" as listed in the root relation's "names".
+ std::shared_ptr<Schema> output_schema =
+ schema({field("t", int32()), field("k", int32()), field("v",
float64())});
+ auto expected_table =
+ TableFromJSON(output_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2,
3.1]]"});
+ CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
+}
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 825a68e68c..0e3732db6e 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -1714,6 +1714,21 @@ bool Schema::HasDistinctFieldNames() const {
return names.size() == fields.size();
}
+// Returns a schema whose i-th field is this schema's i-th field renamed to
+// names[i]. Fails with Invalid unless exactly one name is given per field.
+// NOTE(review): the result is built from the renamed fields alone, so this
+// schema's key-value metadata is not forwarded - confirm that is intended.
+Result<std::shared_ptr<Schema>> Schema::WithNames(
+    const std::vector<std::string>& names) const {
+  if (names.size() != impl_->fields_.size()) {
+    // The count may be too large as well as too small, so the message does
+    // not say "only".
+    return Status::Invalid("attempted to rename schema with ", impl_->fields_.size(),
+                           " fields but ", names.size(), " new names were given");
+  }
+  FieldVector new_fields;
+  new_fields.reserve(names.size());
+  size_t i = 0;
+  for (const auto& field : impl_->fields_) {
+    new_fields.push_back(field->WithName(names[i++]));
+  }
+  return schema(std::move(new_fields));
+}
+
std::shared_ptr<Schema> Schema::WithMetadata(
const std::shared_ptr<const KeyValueMetadata>& metadata) const {
return std::make_shared<Schema>(impl_->fields_, metadata);
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index cf58218a7e..4ea4796231 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1968,6 +1968,12 @@ class ARROW_EXPORT Schema : public
detail::Fingerprintable,
Result<std::shared_ptr<Schema>> SetField(int i,
const std::shared_ptr<Field>&
field) const;
+ /// \brief Replace field names with new names
+ ///
+ /// \param[in] names new names, one per field, applied in field order
+ /// \return new Schema, or Invalid if the number of names differs from the
+ ///     number of fields
+ Result<std::shared_ptr<Schema>> WithNames(const std::vector<std::string>&
names) const;
+
/// \brief Replace key-value metadata with new metadata
///
/// \param[in] metadata new KeyValueMetadata