westonpace commented on code in PR #13285:
URL: https://github.com/apache/arrow/pull/13285#discussion_r887579756
##########
cpp/src/arrow/engine/substrait/expression_internal.cc:
##########
@@ -159,21 +159,11 @@ Result<compute::Expression> FromProto(const
substrait::Expression& expr,
ARROW_ASSIGN_OR_RAISE(auto decoded_function,
ext_set.DecodeFunction(scalar_fn.function_reference()));
+
+ auto arrow_function =
ext_set.GetFunctionMap().GetArrowFromSubstrait(decoded_function.name.to_string());
- std::vector<compute::Expression> arguments(scalar_fn.args_size());
- for (int i = 0; i < scalar_fn.args_size(); ++i) {
- ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i),
ext_set));
- }
+ return arrow_function(scalar_fn);
Review Comment:
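What happens if `arrow_function` is not found? The lookup result is called unconditionally here; if the lookup can fail, we should return an error `Status` instead of invoking it.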
##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -351,5 +355,278 @@ ExtensionIdRegistry* default_extension_id_registry() {
return &impl_;
}
+Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name,
ArrowToSubstrait conversion_func){
+ if (arrow_to_substrait.find(arrow_function_name) !=
arrow_to_substrait.end()){
+ arrow_to_substrait[arrow_function_name] = conversion_func;
+ }
+ return Status::OK();
+}
+
+Status FunctionMapping::AddSubstraitToArrow(std::string
substrait_function_name, SubstraitToArrow conversion_func){
+ if (substrait_to_arrow.find(substrait_function_name) !=
substrait_to_arrow.end()){
+ substrait_to_arrow[substrait_function_name] = conversion_func;
+ }
+ return Status::OK();
+}
+
+SubstraitToArrow substrait_add_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) ->
Result<arrow::compute::Expression> {
+ auto value_1 = call.args(1);
+ auto value_2 = call.args(2);
+ ExtensionSet ext_set_;
+ ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
+ ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
+ auto options = call.args(0);
+ if (options.has_enum_()) {
+ auto overflow_handling = options.enum_();
+ if(overflow_handling.has_specified()){
+ std::string overflow_type = overflow_handling.specified();
+ if(overflow_type == "SILENT"){
+ return arrow::compute::call("add", {expression_1,expression_2},
compute::ArithmeticOptions());
+ } else if (overflow_type == "SATURATE") {
+ return Status::Invalid("Arrow does not support a saturating add");
+ } else {
+ return arrow::compute::call("add_checked", {expression_1,expression_2},
compute::ArithmeticOptions(true));
+ }
+ } else {
+ return arrow::compute::call("add", {expression_1,expression_2},
compute::ArithmeticOptions());
+ }
+ } else {
+ return Status::Invalid("Substrait Function Options should be an enum");
+ }
+};
+
+ArrowToSubstrait arrow_add_to_substrait = [] (const
arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_)
-> Result<substrait::Expression::ScalarFunction> {
+ substrait::Expression::ScalarFunction substrait_call;
+
+ ARROW_ASSIGN_OR_RAISE(auto function_reference,
ext_set_->EncodeFunction("add"));
+ substrait_call.set_function_reference(function_reference);
+
+ substrait::Expression::Enum options;
+ std::string overflow_handling = "ERROR";
+ options.set_specified(overflow_handling);
+ substrait_call.add_args()->set_allocated_enum_(&options);
+
+ auto expression_1 = call.arguments[0];
+ auto expression_2 = call.arguments[1];
+
+ ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
+ ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));
+
+ substrait_call.add_args()->CopyFrom(*value_1);
+ substrait_call.add_args()->CopyFrom(*value_2);
+ return std::move(substrait_call);
+};
+
+ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const
arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) ->
Result<substrait::Expression::ScalarFunction> {
+ substrait::Expression::ScalarFunction substrait_call;
+
+ ARROW_ASSIGN_OR_RAISE(auto function_reference,
ext_set_->EncodeFunction("add"));
+ substrait_call.set_function_reference(function_reference);
+
+ substrait::Expression::Enum options;
+ std::string overflow_handling = "SILENT";
+ options.set_specified(overflow_handling);
+ substrait_call.add_args()->set_allocated_enum_(&options);
+
+ auto expression_1 = call.arguments[0];
+ auto expression_2 = call.arguments[1];
+
+ ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
+ ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));
+
+ substrait_call.add_args()->CopyFrom(*value_1);
+ substrait_call.add_args()->CopyFrom(*value_2);
+ return std::move(substrait_call);
+};
+
+
+// Boolean Functions mapping
+SubstraitToArrow substrait_not_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) ->
Result<arrow::compute::Expression> {
+ auto value_1 = call.args(1);
+ ExtensionSet ext_set_;
+ ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
+ return arrow::compute::call("invert", {expression_1});
+};
+
+SubstraitToArrow substrait_or_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) ->
Result<arrow::compute::Expression> {
+ int num_args = call.args_size(); // OR function has variadic arguments
+ substrait::Expression value;
+ ExtensionSet ext_set_;
+ arrow::compute::Expression expression;
+ std::vector<arrow::compute::Expression> func_args;
+ for(int i=0; i<num_args; ++i){
+ value = call.args(i);
+ ARROW_ASSIGN_OR_RAISE(expression, FromProto(value, ext_set_));
+ func_args.push_back(expression);
+ }
+ return arrow::compute::call("or_kleene", func_args);
+};
+
+SubstraitToArrow substrait_and_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) ->
Result<arrow::compute::Expression> {
+ int num_args = call.args_size(); // AND function has variadic arguments
+ substrait::Expression value;
+ ExtensionSet ext_set_;
+ arrow::compute::Expression expression;
+ std::vector<arrow::compute::Expression> func_args;
+ for(int i=0; i<num_args; ++i){
+ value = call.args(i);
+ ARROW_ASSIGN_OR_RAISE(expression, FromProto(value, ext_set_));
+ func_args.push_back(expression);
+ }
+ return arrow::compute::call("and_kleene", func_args);
+};
+
+SubstraitToArrow substrait_xor_to_arrow = [] (const
substrait::Expression::ScalarFunction& call) ->
Result<arrow::compute::Expression> {
+ auto value_1 = call.args(0);
+ auto value_2 = call.args(1);
+ ExtensionSet ext_set_;
+ ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
+ ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
+ return arrow::compute::call("xor", {expression_1, expression_2});
+};
+
+ArrowToSubstrait arrow_invert_to_substrait = [] (const
arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) ->
Result<substrait::Expression::ScalarFunction> {
+ substrait::Expression::ScalarFunction substrait_call;
+
+ ARROW_ASSIGN_OR_RAISE(auto function_reference,
ext_set_->EncodeFunction("not"));
+ substrait_call.set_function_reference(function_reference);
+
+ auto expression_1 = call.arguments[0];
+ auto expression_2 = call.arguments[1];
+
+ ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
+ ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));
+
+ substrait_call.add_args()->CopyFrom(*value_1);
+ substrait_call.add_args()->CopyFrom(*value_2);
+ return std::move(substrait_call);
+
+};
Review Comment:
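Several functions map one-to-one and fit this simple pattern. Can we create a
helper method that:
* grabs all arguments from the Arrow call,
* converts each argument to a Substrait expression, and
* creates the Substrait call?
Then all we would need to specify is the Substrait function name. A helper that
loops over `call.arguments` would also avoid hard-coded indices like the ones
above: `invert` takes a single argument, so reading `call.arguments[1]` here is
out of bounds.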
##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -351,5 +355,278 @@ ExtensionIdRegistry* default_extension_id_registry() {
return &impl_;
}
+Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name,
ArrowToSubstrait conversion_func){
+ if (arrow_to_substrait.find(arrow_function_name) !=
arrow_to_substrait.end()){
+ arrow_to_substrait[arrow_function_name] = conversion_func;
+ }
Review Comment:
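This condition looks inverted: it only assigns when `arrow_function_name` is already in `arrow_to_substrait`, so a new mapping is never added. Should the check be `== arrow_to_substrait.end()`, with an else clause that returns an invalid status for a duplicate registration?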
##########
cpp/src/arrow/engine/substrait/extension_set.cc:
##########
@@ -351,5 +355,278 @@ ExtensionIdRegistry* default_extension_id_registry() {
return &impl_;
}
+Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name,
ArrowToSubstrait conversion_func){
+ if (arrow_to_substrait.find(arrow_function_name) !=
arrow_to_substrait.end()){
+ arrow_to_substrait[arrow_function_name] = conversion_func;
+ }
+ return Status::OK();
+}
+
+Status FunctionMapping::AddSubstraitToArrow(std::string
substrait_function_name, SubstraitToArrow conversion_func){
+ if (substrait_to_arrow.find(substrait_function_name) !=
substrait_to_arrow.end()){
+ substrait_to_arrow[substrait_function_name] = conversion_func;
+ }
Review Comment:
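Same here: with the `!=` check, a new `substrait_function_name` is never inserted. Perhaps the check should be `==`, with the else branch returning an invalid status.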
##########
cpp/src/arrow/engine/substrait/extension_set.h:
##########
@@ -22,16 +22,55 @@
#include <unordered_map>
#include <vector>
+#include "arrow/compute/function.h"
+#include "arrow/compute/exec/expression.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/type_fwd.h"
#include "arrow/util/optional.h"
#include "arrow/util/string_view.h"
#include "arrow/util/hash_util.h"
+#include "substrait/expression.pb.h" // IWYU pragma: export
namespace arrow {
namespace engine {
+class ExtensionSet;
+using ArrowToSubstrait =
std::function<Result<substrait::Expression::ScalarFunction>(const
arrow::compute::Expression::Call&, arrow::engine::ExtensionSet*)>;
+using SubstraitToArrow =
std::function<Result<arrow::compute::Expression>(const
substrait::Expression::ScalarFunction&)>;
+
+class FunctionMapping {
+
+ enum defined_functions {
Review Comment:
Is this `defined_functions` enum used anywhere? If not, please remove it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]