icexelloss commented on code in PR #14385:
URL: https://github.com/apache/arrow/pull/14385#discussion_r994736304
##########
cpp/src/arrow/engine/substrait/serde_test.cc:
##########
@@ -3187,5 +3198,164 @@ TEST(Substrait, IsthmusPlan) {
*compute::default_exec_context(), buf, {},
conversion_options);
}
+TEST(Substrait, PlanWithExtension) {
+ // This demos an extension relation
+ std::string substrait_json = R"({
+ "extensionUris": [],
+ "extensions": [],
+ "relations": [{
+ "root": {
+ "input": {
+ "extension_multi": {
+ "common": {
+ "emit": {
+ "outputMapping": [0, 1, 2, 3]
+ }
+ },
+ "inputs": [
+ {
+ "read": {
+ "common": {
+ "direct": {
+ }
+ },
+ "baseSchema": {
+ "names": ["time", "key", "value1"],
+ "struct": {
+ "types": [
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "fp64": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": ["T1"]
+ }
+ }
+ },
+ {
+ "read": {
+ "common": {
+ "direct": {
+ }
+ },
+ "baseSchema": {
+ "names": ["time", "key", "value2"],
+ "struct": {
+ "types": [
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "i32": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ },
+ {
+ "fp64": {
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_NULLABLE"
+ }
+ }
+ ],
+ "typeVariationReference": 0,
+ "nullability": "NULLABILITY_REQUIRED"
+ }
+ },
+ "namedTable": {
+ "names": ["T2"]
+ }
+ }
+ }
+ ],
+ "detail": {
+ "@type": "/arrow.substrait.AsOfJoinRel",
+ "on": {
+ "selection": {
+ "directReference": {
+ "structField": {
+ "field": 0,
+ }
+ },
+ "rootReference": {}
+ }
+ },
+ "by": [
Review Comment:
Shouldn't this be a list of field refs - one for each table in the join?
##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -951,28 +953,27 @@ class AsofJoinNode : public ExecNode {
}
static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
- const std::vector<ExecNode*>& inputs,
+ const std::vector<std::shared_ptr<Schema>> input_schema,
Review Comment:
kk sounds reasonable
##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -951,28 +953,27 @@ class AsofJoinNode : public ExecNode {
}
static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
- const std::vector<ExecNode*>& inputs,
+ const std::vector<std::shared_ptr<Schema>> input_schema,
const std::vector<col_index_t>& indices_of_on_key,
const std::vector<std::vector<col_index_t>>& indices_of_by_key) {
std::vector<std::shared_ptr<arrow::Field>> fields;
- size_t n_by = indices_of_by_key[0].size();
+ size_t n_by = indices_of_by_key.size() == 0 ? 0 :
indices_of_by_key[0].size();
Review Comment:
When can `indices_of_by_key.size() == 0` happen? Does it mean this is a join
without a `by` key?
##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -1030,6 +1031,32 @@ class AsofJoinNode : public ExecNode {
return match.indices()[0];
}
+ static Result<std::vector<col_index_t>> GetIndicesOfOnKey(
+ const std::vector<std::shared_ptr<Schema>>& input_schema, const
FieldRef& on_key) {
+ size_t n_input = input_schema.size();
+ std::vector<col_index_t> indices_of_on_key(n_input);
+ for (size_t i = 0; i < n_input; ++i) {
+ ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
+ FindColIndex(*input_schema[i], on_key, "on"));
Review Comment:
I am not sure we need this function - in practice this Substrait field ref
should always be index-based.
##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -1030,6 +1031,32 @@ class AsofJoinNode : public ExecNode {
return match.indices()[0];
}
+ static Result<std::vector<col_index_t>> GetIndicesOfOnKey(
+ const std::vector<std::shared_ptr<Schema>>& input_schema, const
FieldRef& on_key) {
+ size_t n_input = input_schema.size();
+ std::vector<col_index_t> indices_of_on_key(n_input);
+ for (size_t i = 0; i < n_input; ++i) {
+ ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
+ FindColIndex(*input_schema[i], on_key, "on"));
Review Comment:
I would prefer that we add this function only if we ever support name-based
field refs in Substrait (which I doubt we will ever do).
##########
cpp/proto/substrait/extension_rels.proto:
##########
@@ -0,0 +1,32 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+syntax = "proto3";
+
+package arrow.substrait;
+
+import "substrait/algebra.proto";
+
+option csharp_namespace = "Arrow.Substrait";
+option go_package = "github.com/apache/arrow/substrait";
+option java_multiple_files = true;
+option java_package = "io.arrow.substrait";
+
+message AsOfJoinRel {
+ .substrait.Expression on = 1;
Review Comment:
We are using `StructField` for this, so I don't think we can use a name for
this, right?
I think it probably makes more sense to have this be just a repeated field,
in case the `on` column doesn't share the same column index in each table.
##########
cpp/src/arrow/engine/substrait/options.h:
##########
@@ -32,7 +36,7 @@ namespace engine {
/// How strictly to adhere to the input structure when converting between
Substrait and
/// Acero representations of a plan. This allows the user to trade conversion
accuracy
/// for performance and lenience.
-enum class ConversionStrictness {
+enum class ARROW_ENGINE_EXPORT ConversionStrictness {
Review Comment:
What does the macro `ARROW_ENGINE_EXPORT` do?
##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include <iostream>
+
+#include "arrow/engine/substrait/options.h"
+
+#include <google/protobuf/util/json_util.h>
+#include "arrow/compute/exec/asof_join_node.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/engine/substrait/expression_internal.h"
+#include "arrow/engine/substrait/relation_internal.h"
+#include "substrait/extension_rels.pb.h"
+
+namespace arrow {
+namespace engine {
+
+class DefaultExtensionProvider : public ExtensionProvider {
+ public:
+ Result<DeclarationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+ const google::protobuf::Any& rel,
+ const ExtensionSet& ext_set) override {
+ if (rel.Is<arrow::substrait::AsOfJoinRel>()) {
+ arrow::substrait::AsOfJoinRel as_of_join_rel;
+ rel.UnpackTo(&as_of_join_rel);
+ return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
+ }
+ return Status::NotImplemented("Unrecognized extension in Susbstrait plan:
",
+ rel.DebugString());
+ }
+
+ private:
+ Result<DeclarationInfo> MakeAsOfJoinRel(
+ const std::vector<DeclarationInfo>& inputs,
+ const arrow::substrait::AsOfJoinRel& as_of_join_rel, const ExtensionSet&
ext_set) {
+ if (inputs.size() < 2) {
+ return Status::Invalid("substrait::AsOfJoinNode too few input tables: ",
+ inputs.size());
+ }
+ // on-key
+ if (!as_of_join_rel.has_on()) {
+ return Status::Invalid("substrait::AsOfJoinNode missing on-key");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(as_of_join_rel.on(),
ext_set, {}));
+ if (on_key_expr.field_ref() == NULLPTR) {
+ return Status::NotImplemented("substrait::AsOfJoinNode non-field-ref
on-key");
+ }
+ const FieldRef& on_key = *on_key_expr.field_ref();
+
+ // by-key
+ std::vector<FieldRef> by_key;
Review Comment:
I think `by_keys` here is a little clearer since this is a vector
##########
cpp/src/arrow/engine/substrait/options.h:
##########
@@ -65,16 +69,27 @@ using NamedTableProvider =
std::function<Result<compute::Declaration>(const
std::vector<std::string>&)>;
static NamedTableProvider kDefaultNamedTableProvider;
+class ARROW_ENGINE_EXPORT ExtensionProvider {
+ public:
+ static std::shared_ptr<ExtensionProvider> kDefaultExtensionProvider;
Review Comment:
Should we define this the same way that `kDefaultNamedTableProvider` is
defined? I don't see a reason why those two should be different
##########
cpp/src/arrow/compute/exec/asof_join_node.cc:
##########
@@ -1030,6 +1031,32 @@ class AsofJoinNode : public ExecNode {
return match.indices()[0];
}
+ static Result<std::vector<col_index_t>> GetIndicesOfOnKey(
+ const std::vector<std::shared_ptr<Schema>>& input_schema, const
FieldRef& on_key) {
+ size_t n_input = input_schema.size();
+ std::vector<col_index_t> indices_of_on_key(n_input);
+ for (size_t i = 0; i < n_input; ++i) {
+ ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
+ FindColIndex(*input_schema[i], on_key, "on"));
+ }
+ return indices_of_on_key;
+ }
+
+ static Result<std::vector<std::vector<col_index_t>>> GetIndicesOfByKey(
Review Comment:
Same comment as `GetIndicesOfOnKey`
##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include <iostream>
+
+#include "arrow/engine/substrait/options.h"
+
+#include <google/protobuf/util/json_util.h>
+#include "arrow/compute/exec/asof_join_node.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/engine/substrait/expression_internal.h"
+#include "arrow/engine/substrait/relation_internal.h"
+#include "substrait/extension_rels.pb.h"
+
+namespace arrow {
+namespace engine {
+
+class DefaultExtensionProvider : public ExtensionProvider {
+ public:
+ Result<DeclarationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+ const google::protobuf::Any& rel,
+ const ExtensionSet& ext_set) override {
+ if (rel.Is<arrow::substrait::AsOfJoinRel>()) {
+ arrow::substrait::AsOfJoinRel as_of_join_rel;
+ rel.UnpackTo(&as_of_join_rel);
+ return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
+ }
+ return Status::NotImplemented("Unrecognized extension in Susbstrait plan:
",
+ rel.DebugString());
+ }
+
+ private:
+ Result<DeclarationInfo> MakeAsOfJoinRel(
+ const std::vector<DeclarationInfo>& inputs,
+ const arrow::substrait::AsOfJoinRel& as_of_join_rel, const ExtensionSet&
ext_set) {
+ if (inputs.size() < 2) {
+ return Status::Invalid("substrait::AsOfJoinNode too few input tables: ",
+ inputs.size());
+ }
+ // on-key
+ if (!as_of_join_rel.has_on()) {
+ return Status::Invalid("substrait::AsOfJoinNode missing on-key");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(as_of_join_rel.on(),
ext_set, {}));
+ if (on_key_expr.field_ref() == NULLPTR) {
+ return Status::NotImplemented("substrait::AsOfJoinNode non-field-ref
on-key");
+ }
+ const FieldRef& on_key = *on_key_expr.field_ref();
+
+ // by-key
+ std::vector<FieldRef> by_key;
+ for (const auto& by_item : as_of_join_rel.by()) {
+ ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {}));
+ if (by_key_expr.field_ref() == NULLPTR) {
+ return Status::NotImplemented("substrait::AsOfJoinNode non-field-ref
by-key");
+ }
+ by_key.push_back(*by_key_expr.field_ref());
+ }
+
+ // schema
+ int64_t tolerance = as_of_join_rel.tolerance();
+ std::vector<std::shared_ptr<Schema>> input_schema(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ input_schema[i] = inputs[i].output_schema;
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ auto schema, compute::asofjoin::MakeOutputSchema(input_schema, on_key,
by_key));
+ compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(on_key),
std::move(by_key),
+ tolerance};
+
+ // declaration
+ std::vector<compute::Declaration::Input> input_decls(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ input_decls[i] = inputs[i].declaration;
+ }
+ return DeclarationInfo{
+ compute::Declaration("asofjoin", input_decls,
std::move(asofjoin_node_opts)),
+ std::move(schema)};
+ }
+};
+
+std::shared_ptr<ExtensionProvider>
ExtensionProvider::kDefaultExtensionProvider =
Review Comment:
Nvm this, I see the declaration there in options.h
##########
cpp/src/arrow/engine/substrait/options.cc:
##########
@@ -0,0 +1,99 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#include <iostream>
+
+#include "arrow/engine/substrait/options.h"
+
+#include <google/protobuf/util/json_util.h>
+#include "arrow/compute/exec/asof_join_node.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/engine/substrait/expression_internal.h"
+#include "arrow/engine/substrait/relation_internal.h"
+#include "substrait/extension_rels.pb.h"
+
+namespace arrow {
+namespace engine {
+
+class DefaultExtensionProvider : public ExtensionProvider {
+ public:
+ Result<DeclarationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
+ const google::protobuf::Any& rel,
+ const ExtensionSet& ext_set) override {
+ if (rel.Is<arrow::substrait::AsOfJoinRel>()) {
+ arrow::substrait::AsOfJoinRel as_of_join_rel;
+ rel.UnpackTo(&as_of_join_rel);
+ return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
+ }
+ return Status::NotImplemented("Unrecognized extension in Susbstrait plan:
",
+ rel.DebugString());
+ }
+
+ private:
+ Result<DeclarationInfo> MakeAsOfJoinRel(
+ const std::vector<DeclarationInfo>& inputs,
+ const arrow::substrait::AsOfJoinRel& as_of_join_rel, const ExtensionSet&
ext_set) {
+ if (inputs.size() < 2) {
+ return Status::Invalid("substrait::AsOfJoinNode too few input tables: ",
+ inputs.size());
+ }
+ // on-key
+ if (!as_of_join_rel.has_on()) {
+ return Status::Invalid("substrait::AsOfJoinNode missing on-key");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto on_key_expr, FromProto(as_of_join_rel.on(),
ext_set, {}));
+ if (on_key_expr.field_ref() == NULLPTR) {
+ return Status::NotImplemented("substrait::AsOfJoinNode non-field-ref
on-key");
+ }
+ const FieldRef& on_key = *on_key_expr.field_ref();
+
+ // by-key
+ std::vector<FieldRef> by_key;
+ for (const auto& by_item : as_of_join_rel.by()) {
+ ARROW_ASSIGN_OR_RAISE(auto by_key_expr, FromProto(by_item, ext_set, {}));
+ if (by_key_expr.field_ref() == NULLPTR) {
+ return Status::NotImplemented("substrait::AsOfJoinNode non-field-ref
by-key");
+ }
+ by_key.push_back(*by_key_expr.field_ref());
+ }
+
+ // schema
+ int64_t tolerance = as_of_join_rel.tolerance();
+ std::vector<std::shared_ptr<Schema>> input_schema(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ input_schema[i] = inputs[i].output_schema;
+ }
+ ARROW_ASSIGN_OR_RAISE(
+ auto schema, compute::asofjoin::MakeOutputSchema(input_schema, on_key,
by_key));
+ compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(on_key),
std::move(by_key),
+ tolerance};
+
+ // declaration
+ std::vector<compute::Declaration::Input> input_decls(inputs.size());
+ for (size_t i = 0; i < inputs.size(); i++) {
+ input_decls[i] = inputs[i].declaration;
+ }
+ return DeclarationInfo{
+ compute::Declaration("asofjoin", input_decls,
std::move(asofjoin_node_opts)),
+ std::move(schema)};
+ }
+};
+
+std::shared_ptr<ExtensionProvider>
ExtensionProvider::kDefaultExtensionProvider =
Review Comment:
Should we move this to the substrait/options.h file? (Since
`kDefaultNamedTableProvider` is also there)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]