westonpace commented on code in PR #13285: URL: https://github.com/apache/arrow/pull/13285#discussion_r899666451
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ Review Comment: Same as above. This seems backwards (but maybe I'm just not thinking right) ########## cpp/src/arrow/engine/substrait/expression_internal.cc: ########## @@ -159,21 +160,17 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr, ARROW_ASSIGN_OR_RAISE(auto decoded_function, ext_set.DecodeFunction(scalar_fn.function_reference())); + ARROW_ASSIGN_OR_RAISE(auto arrow_function, ext_set.GetFunctionMap().GetArrowFromSubstrait(decoded_function.name.to_string())); + return arrow_function(scalar_fn); + } - std::vector<compute::Expression> arguments(scalar_fn.args_size()); - for (int i = 0; i < scalar_fn.args_size(); ++i) { - ARROW_ASSIGN_OR_RAISE(arguments[i], FromProto(scalar_fn.args(i), ext_set)); - } - - auto func_name = decoded_function.name.to_string(); - if (func_name != "cast") { - return compute::call(func_name, std::move(arguments)); - } else { - ARROW_ASSIGN_OR_RAISE(auto output_type_desc, - FromProto(scalar_fn.output_type(), ext_set)); - auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); - return compute::call(func_name, std::move(arguments), std::move(cast_options)); - } + case substrait::Expression::kEnum: { + auto enum_expr = expr.enum_(); Review 
Comment: Does this convert to the string value of the enum? Can you add a small comment here explaining that. ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; Review Comment: ```suggestion ExtensionSet ext_set; ``` 
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; Review Comment: This seems strange. Wouldn't this function take in an extension set as an argument? 
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = 
FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*value); + } + return std::move(substrait_call); +} + +substrait::Expression::ScalarFunction arrow_convert_enum_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ + substrait::Expression::Enum options; + options.set_specified(overflow_handling); + substrait_call.add_args()->set_allocated_enum_(&options); + return arrow_convert_arguments(call, substrait_call, ext_set_); +} + + +SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } + }; + +SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if 
(func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating subtract"); + } else { + return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("multiply", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating multiply"); + } else { + return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_divide_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("divide", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating divide"); + } else { + return arrow::compute::call("divide_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_modulus_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("abs", substrait_convert_arguments(call)); +}; + +ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); + 
substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); + }; + +ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + +ArrowToSubstrait arrow_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT") ; +}; + +ArrowToSubstrait arrow_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return 
arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + + +ArrowToSubstrait arrow_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + +ArrowToSubstrait arrow_abs_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("modulus")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Boolean Functions mappings +SubstraitToArrow 
substrait_not_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("invert", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_or_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("or_kleene", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_and_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("and_kleene", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_xor_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("xor", substrait_convert_arguments(call)); +}; + +ArrowToSubstrait arrow_invert_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_or_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("or")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_and_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + 
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("and")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_xor_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("xor")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Comparison Functions mapping +SubstraitToArrow substrait_lt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("less", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_gt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("greater", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("less_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_gte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("greater_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_not_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("not_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("equal", 
substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("is_null", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_not_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("is_valid", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_not_distinct_from_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + std::vector<compute::Expression> func_args = substrait_convert_arguments(call); + auto null_check_1 = arrow::compute::call("is_null", {func_args[0]}); + auto null_check_2 = arrow::compute::call("is_null", {func_args[1]}); + if(null_check_1.IsNullLiteral() && null_check_1.IsNullLiteral()){ + return arrow::compute::call("not_equal", {null_check_1, null_check_2}); + } + return arrow::compute::call("not_equal", func_args); +}; + +ArrowToSubstrait arrow_less_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("lt")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_greater_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("gt")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait 
arrow_less_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("lte")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_greater_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("gte")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("equal")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_not_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not_equal")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_is_null_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction 
substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("is_null")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_is_valid_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("is_not_null")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Strings function mapping +SubstraitToArrow substrait_like_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + return arrow::compute::call("match_like", {func_args[0]}, compute::MatchSubstringOptions(func_args[1].ToString())); +}; + +SubstraitToArrow substrait_substring_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + auto start = func_args[1].literal()->scalar_as<Int64Scalar>(); + auto stop = func_args[2].literal()->scalar_as<Int64Scalar>(); + return arrow::compute::call("utf8_slice_codeunits", {func_args[0]}, compute::SliceOptions(static_cast<int64_t>(start.value), static_cast<int64_t>(stop.value))); +}; + +SubstraitToArrow substrait_concat_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + arrow::StringBuilder builder; + builder.Append(func_args[0].ToString()); + builder.Append(func_args[1].ToString()); + auto strings_datum = arrow::Datum(*builder.Finish()); + auto separator_datum = arrow::Datum(""); + return arrow::compute::call("binary_join", 
{arrow::compute::Expression(strings_datum), arrow::compute::Expression(separator_datum)}); +}; + +ArrowToSubstrait arrow_match_like_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("like")); + substrait_call.set_function_reference(function_reference); + + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr<substrait::Expression> string_1, string_2; + expression_1 = call.arguments[0]; + string_1 = ToProto(expression_1, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_1); + + auto pattern_string = std::dynamic_pointer_cast<compute::MatchSubstringOptions>(call.options)->pattern; + expression_2 = arrow::compute::Expression(arrow::Datum(pattern_string)); + string_2 = ToProto(expression_2, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_2); + + return std::move(substrait_call); +}; + +ArrowToSubstrait arrow_utf8_slice_codeunits_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("substring")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2, expression_3; + std::unique_ptr<substrait::Expression> string, start, stop; + expression_1 = call.arguments[0]; + string = ToProto(expression_1, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string); + + auto start_index = std::dynamic_pointer_cast<compute::SliceOptions>(call.options)->start; + auto stop_index = std::dynamic_pointer_cast<compute::SliceOptions>(call.options)->stop; + expression_2 = arrow::compute::Expression(arrow::Datum(start_index)); + 
expression_3 = arrow::compute::Expression(arrow::Datum(stop_index)); + start = ToProto(expression_2, ext_set_).ValueOrDie(); + stop = ToProto(expression_3, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*start); + substrait_call.add_args()->CopyFrom(*stop); + + return std::move(substrait_call); +}; + +ArrowToSubstrait arrow_binary_join_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("concat")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr<substrait::Expression> string_1, string_2; + + auto strings_list = call.arguments[0].literal()->make_array(); + expression_1 = arrow::compute::Expression(*(strings_list->GetScalar(0))); + expression_2 = arrow::compute::Expression(*(strings_list->GetScalar(1))); + + string_1 = ToProto(expression_1, ext_set_).ValueOrDie(); + string_2 = ToProto(expression_2, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_1); + substrait_call.add_args()->CopyFrom(*string_2); + return std::move(substrait_call); +}; + +// Cast function mapping +SubstraitToArrow substrait_cast_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + ExtensionSet ext_set_; + ARROW_ASSIGN_OR_RAISE(auto output_type_desc, + FromProto(call.output_type(), ext_set_)); + auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); + return compute::call("cast", {substrait_convert_arguments(call)[0]}, std::move(cast_options)); +}; + +ArrowToSubstrait arrow_cast_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + 
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("cast")); + substrait_call.set_function_reference(function_reference); + + auto arrow_to_type = std::dynamic_pointer_cast<compute::CastOptions>(call.options)->to_type; + ARROW_ASSIGN_OR_RAISE(auto substrait_to_type, ToProto(*arrow_to_type, false, ext_set_)); + substrait_call.set_allocated_output_type(substrait_to_type.get()); + + auto expression = call.arguments[0]; + ARROW_ASSIGN_OR_RAISE(auto value, ToProto(expression, ext_set_)); + substrait_call.add_args()->CopyFrom(*value); + + return substrait_call; +}; + +// Datetime functions mapping +SubstraitToArrow substrait_extract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "YEAR"){ + return arrow::compute::call("year", {func_args[1]}); + } else if (func_args[0].ToString() == "MONTH") { + return arrow::compute::call("month", {func_args[1]}); + } else if (func_args[0].ToString() == "DAY") { + return arrow::compute::call("day", {func_args[1]}); + } else { + return arrow::compute::call("second", {func_args[1]}); + } +}; + +ArrowToSubstrait arrow_year_to_arrow = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "YEAR"); +}; + +ArrowToSubstrait arrow_month_to_arrow = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("extract")); + 
substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "MONTH"); +}; + +ArrowToSubstrait arrow_day_to_arrow = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("extract")); + substrait_call.set_function_reference(function_reference); Review Comment: All of these calls to `EncodeFunction` seem pretty repetitive. Is there any way we can move this into the part that calls `GetSubstraitFromArrow`? Also, I don't see anything today that calls `GetSubstraitFromArrow`. ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't
exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ Review Comment: The style guide prefers passing mutable objects (such as `substrait_call` here) by pointer instead of by non-const reference.
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = 
FromProto(value, ext_set_).ValueOrDie(); Review Comment: Don't use `ValueOrDie` if it could possibly fail. Return a `Result` instead and use `ARROW_ASSIGN_OR_RAISE`. ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ Review Comment: This logic seems backwards to me. Wouldn't `umap.find(...) != umap.end()` mean the item already exists? ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait>
FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ Review Comment: ```suggestion substrait::Expression::ScalarFunction ConvertArrowArguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction* substrait_call, ExtensionSet* ext_set){ ``` ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = 
conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ Review Comment: ```suggestion std::vector<arrow::compute::Expression> ConvertSubstraitArguments(const substrait::Expression::ScalarFunction& call){ ``` ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return 
Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*value); + } + return std::move(substrait_call); +} + +substrait::Expression::ScalarFunction arrow_convert_enum_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ Review Comment: ```suggestion 
substrait::Expression::ScalarFunction ConvertArrowEnumArguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set, std::string overflow_handling){ ``` ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + 
substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*value); + } + return std::move(substrait_call); +} + +substrait::Expression::ScalarFunction arrow_convert_enum_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ + substrait::Expression::Enum options; + options.set_specified(overflow_handling); Review Comment: `overflow_handling` seems like an odd name given this is a generic function (it is also used to pass the `extract` component, e.g. "YEAR" or "MONTH"). ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ +
substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); Review Comment: Again, don't use `ValueOrDie`. 
########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = 
FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*value); + } + return std::move(substrait_call); +} + +substrait::Expression::ScalarFunction arrow_convert_enum_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ + substrait::Expression::Enum options; + options.set_specified(overflow_handling); + substrait_call.add_args()->set_allocated_enum_(&options); + return arrow_convert_arguments(call, substrait_call, ext_set_); +} + + +SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } + }; + +SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if 
(func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating subtract"); + } else { + return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("multiply", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating multiply"); + } else { + return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_divide_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("divide", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating divide"); + } else { + return arrow::compute::call("divide_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_modulus_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("abs", substrait_convert_arguments(call)); +}; + +ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { Review Comment: There are a lot of places where you have `ext_set_` and it should probably be `ext_set` (the trailing underscore is meant for class data members, not parameters or locals).
For the sake of brevity, I'm not going to mark them all. ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -288,6 +808,11 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return Status::OK(); } + Status RegisterFunctionMapping(Id id, SubstraitToArrow conversion_func) override { + DCHECK_OK(functions_map.AddSubstraitToArrow(id.name.to_string(), conversion_func)); + return RegisterFunction(id, id.name.to_string()); Review Comment: It seems a little odd that we need two maps. What happens if two functions exist with the same name but different URIs? Thinking on this longer, maybe `substrait_to_arrow` should replace the map in the extension ID registry (the one that gets updated by the call to `RegisterFunction`)? ########## cpp/src/arrow/engine/substrait/extension_set.cc: ########## @@ -204,6 +209,521 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } +Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, ArrowToSubstrait conversion_func){ + if (arrow_to_substrait.find(arrow_function_name) != arrow_to_substrait.end()){ + arrow_to_substrait[arrow_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Arrow function already exist in the conversion map"); +} + +Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, SubstraitToArrow conversion_func){ + if (substrait_to_arrow.find(substrait_function_name) != substrait_to_arrow.end()){ + substrait_to_arrow[substrait_function_name] = conversion_func; + return Status::OK(); + } + return Status::AlreadyExists("Substrait function already exist in the conversion map"); +} + +Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const { + if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ + return FunctionMapping::substrait_to_arrow.at(name); + } else { + return Status::KeyError("Substrait function doesn't
exist in the mapping registry"); + } + } + +Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const { + if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ + return FunctionMapping::arrow_to_substrait.at(name); + } else { + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); + } + } + +std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; + ExtensionSet ext_set_; + arrow::compute::Expression expression; + std::vector<compute::Expression> func_args; + for(int i=0; i<call.args_size(); ++i){ + value = call.args(i); + expression = FromProto(value, ext_set_).ValueOrDie(); + func_args.push_back(expression); + } + return func_args; +} + +substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){ + arrow::compute::Expression expression; + std::unique_ptr<substrait::Expression> value; + for(size_t i = 0; i<call.arguments.size(); ++i){ + expression = call.arguments[i]; + value = ToProto(expression, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*value); + } + return std::move(substrait_call); +} + +substrait::Expression::ScalarFunction arrow_convert_enum_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ + substrait::Expression::Enum options; + options.set_specified(overflow_handling); + substrait_call.add_args()->set_allocated_enum_(&options); + return arrow_convert_arguments(call, substrait_call, ext_set_); +} + + +SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + 
if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } + }; + +SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating subtract"); + } else { + return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("multiply", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating multiply"); + } else { + return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_divide_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("divide", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { 
+ return Status::Invalid("Arrow does not support a saturating divide"); + } else { + return arrow::compute::call("divide_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_modulus_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("abs", substrait_convert_arguments(call)); +}; + +ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); + }; + +ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + +ArrowToSubstrait arrow_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> 
Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT") ; +}; + +ArrowToSubstrait arrow_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + + +ArrowToSubstrait arrow_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_divide_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction 
substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("divide")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_enum_arguments(call, substrait_call, ext_set_, "SILENT"); +}; + +ArrowToSubstrait arrow_abs_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("modulus")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Boolean Functions mappings +SubstraitToArrow substrait_not_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("invert", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_or_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("or_kleene", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_and_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("and_kleene", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_xor_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("xor", substrait_convert_arguments(call)); +}; + +ArrowToSubstrait arrow_invert_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not")); + substrait_call.set_function_reference(function_reference); + return 
arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_or_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("or")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_and_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("and")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_xor_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("xor")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Comparison Functions mapping +SubstraitToArrow substrait_lt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("less", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_gt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("greater", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> 
Result<arrow::compute::Expression> { + return arrow::compute::call("less_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_gte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("greater_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_not_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("not_equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("equal", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("is_null", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_not_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + return arrow::compute::call("is_valid", substrait_convert_arguments(call)); +}; + +SubstraitToArrow substrait_is_not_distinct_from_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + std::vector<compute::Expression> func_args = substrait_convert_arguments(call); + auto null_check_1 = arrow::compute::call("is_null", {func_args[0]}); + auto null_check_2 = arrow::compute::call("is_null", {func_args[1]}); + if(null_check_1.IsNullLiteral() && null_check_1.IsNullLiteral()){ + return arrow::compute::call("not_equal", {null_check_1, null_check_2}); + } + return arrow::compute::call("not_equal", func_args); +}; + +ArrowToSubstrait arrow_less_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> 
Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("lt")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_greater_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("gt")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_less_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("lte")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_greater_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("gte")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("equal")); + 
substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_not_equal_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not_equal")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_is_null_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("is_null")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +ArrowToSubstrait arrow_is_valid_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("is_not_null")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arguments(call, substrait_call, ext_set_); +}; + +// Strings function mapping +SubstraitToArrow substrait_like_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + return arrow::compute::call("match_like", {func_args[0]}, compute::MatchSubstringOptions(func_args[1].ToString())); +}; + +SubstraitToArrow substrait_substring_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = 
substrait_convert_arguments(call); + auto start = func_args[1].literal()->scalar_as<Int64Scalar>(); + auto stop = func_args[2].literal()->scalar_as<Int64Scalar>(); + return arrow::compute::call("utf8_slice_codeunits", {func_args[0]}, compute::SliceOptions(static_cast<int64_t>(start.value), static_cast<int64_t>(stop.value))); +}; + +SubstraitToArrow substrait_concat_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + arrow::StringBuilder builder; + builder.Append(func_args[0].ToString()); + builder.Append(func_args[1].ToString()); + auto strings_datum = arrow::Datum(*builder.Finish()); + auto separator_datum = arrow::Datum(""); + return arrow::compute::call("binary_join", {arrow::compute::Expression(strings_datum), arrow::compute::Expression(separator_datum)}); +}; + +ArrowToSubstrait arrow_match_like_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("like")); + substrait_call.set_function_reference(function_reference); + + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr<substrait::Expression> string_1, string_2; + expression_1 = call.arguments[0]; + string_1 = ToProto(expression_1, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_1); + + auto pattern_string = std::dynamic_pointer_cast<compute::MatchSubstringOptions>(call.options)->pattern; + expression_2 = arrow::compute::Expression(arrow::Datum(pattern_string)); + string_2 = ToProto(expression_2, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_2); + + return std::move(substrait_call); +}; + +ArrowToSubstrait arrow_utf8_slice_codeunits_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) 
-> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("substring")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2, expression_3; + std::unique_ptr<substrait::Expression> string, start, stop; + expression_1 = call.arguments[0]; + string = ToProto(expression_1, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string); + + auto start_index = std::dynamic_pointer_cast<compute::SliceOptions>(call.options)->start; + auto stop_index = std::dynamic_pointer_cast<compute::SliceOptions>(call.options)->stop; + expression_2 = arrow::compute::Expression(arrow::Datum(start_index)); + expression_3 = arrow::compute::Expression(arrow::Datum(stop_index)); + start = ToProto(expression_2, ext_set_).ValueOrDie(); + stop = ToProto(expression_3, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*start); + substrait_call.add_args()->CopyFrom(*stop); + + return std::move(substrait_call); +}; + +ArrowToSubstrait arrow_binary_join_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("concat")); + substrait_call.set_function_reference(function_reference); + arrow::compute::Expression expression_1, expression_2; + std::unique_ptr<substrait::Expression> string_1, string_2; + + auto strings_list = call.arguments[0].literal()->make_array(); + expression_1 = arrow::compute::Expression(*(strings_list->GetScalar(0))); + expression_2 = arrow::compute::Expression(*(strings_list->GetScalar(1))); + + string_1 = ToProto(expression_1, ext_set_).ValueOrDie(); + string_2 = ToProto(expression_2, ext_set_).ValueOrDie(); + substrait_call.add_args()->CopyFrom(*string_1); + 
substrait_call.add_args()->CopyFrom(*string_2); + return std::move(substrait_call); +}; + +// Cast function mapping +SubstraitToArrow substrait_cast_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + ExtensionSet ext_set_; + ARROW_ASSIGN_OR_RAISE(auto output_type_desc, + FromProto(call.output_type(), ext_set_)); + auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); + return compute::call("cast", {substrait_convert_arguments(call)[0]}, std::move(cast_options)); +}; + +ArrowToSubstrait arrow_cast_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("cast")); + substrait_call.set_function_reference(function_reference); + + auto arrow_to_type = std::dynamic_pointer_cast<compute::CastOptions>(call.options)->to_type; + ARROW_ASSIGN_OR_RAISE(auto substrait_to_type, ToProto(*arrow_to_type, false, ext_set_)); + substrait_call.set_allocated_output_type(substrait_to_type.get()); + + auto expression = call.arguments[0]; + ARROW_ASSIGN_OR_RAISE(auto value, ToProto(expression, ext_set_)); + substrait_call.add_args()->CopyFrom(*value); + + return substrait_call; +}; + +// Datetime functions mapping +SubstraitToArrow substrait_extract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "YEAR"){ + return arrow::compute::call("year", {func_args[1]}); + } else if (func_args[0].ToString() == "MONTH") { + return arrow::compute::call("month", {func_args[1]}); + } else if (func_args[0].ToString() == "DAY") { + return arrow::compute::call("day", {func_args[1]}); + } else { + return arrow::compute::call("second", {func_args[1]}); + } +}; + 
+ArrowToSubstrait arrow_year_to_arrow = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> { Review Comment: `arrow_...to_arrow`? This lambda is an `ArrowToSubstrait` converter, so shouldn't it be named `arrow_year_to_substrait` to match the other `arrow_*_to_substrait` functions? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org