kou commented on code in PR #37787: URL: https://github.com/apache/arrow/pull/37787#discussion_r1335320160
########## cpp/src/gandiva/extension_tests/CMakeLists.txt: ########## @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NO_TESTS) + return() +endif() + +# copy the testing data into the build directory +add_custom_target(extension-tests-data + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_BINARY_DIR}/gandiva_extension_tests) Review Comment: It seems that `CMAKE_CURRENT_BINARY_DIR` is better. Why do we want to use the top build directory for this? ########## cpp/src/gandiva/extension_tests/CMakeLists.txt: ########## @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NO_TESTS) + return() +endif() + +# copy the testing data into the build directory +add_custom_target(extension-tests-data + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_BINARY_DIR}/gandiva_extension_tests) + +include(../cmake/GenerateBitcode.cmake) + +set(TEST_EXT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/extended_funcs) +set(TEST_PRECOMPILED_SRCS ${TEST_EXT_DIR}/multiply_by_two.cc) +generate_bitcode("${TEST_PRECOMPILED_SRCS}" + "../../../gandiva_extension_tests/extended_funcs/" TEST_BC_FILES) Review Comment: `CMAKE_BINARY_DIR`? ```suggestion "${CMAKE_BINARY_DIR}/gandiva_extension_tests/extended_funcs/" TEST_BC_FILES) ``` ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { Review Comment: `if_null` may be better. ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if (!func["param_types"].IsArray()) { + return 
Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDecimalDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); + if (type_name == "decimal128") { + return arrow::decimal128(precision, scale); + } else if (type_name == "decimal256") { + return arrow::decimal256(precision, scale); + } Review Comment: Do we need `else` for an invalid value? 
########## cpp/src/gandiva/engine.cc: ########## @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = + llvm::MemoryBuffer::getFile(entry.path().string()); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR file: ", + entry.path().string() + " Error: " + + buffer_or_error.getError().message())); + + std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(buffer_or_error.get()); + + llvm::Expected<std::unique_ptr<llvm::Module>> module_or_error = + llvm::parseBitcodeFile(buffer->getMemBufferRef(), *context()); + if (!module_or_error) { + std::string str; + llvm::raw_string_ostream stream(str); + stream << module_or_error.takeError(); + return Status::CodeGenError("Failed to parse bitcode file: " + + entry.path().string() + " Error: " + stream.str()); + } + std::unique_ptr<llvm::Module> ir_module = std::move(module_or_error.get()); Review Comment: ```suggestion auto ir_module = std::move(module_or_error.get()); ``` ########## cpp/src/gandiva/engine.cc: ########## @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + 
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = Review Comment: Can we use `auto` here? ```suggestion auto buffer_or_error = ``` ########## cpp/src/gandiva/engine.cc: ########## @@ -137,6 +137,10 @@ Status Engine::LoadFunctionIRs() { if (!functions_loaded_) { ARROW_RETURN_NOT_OK(LoadPreCompiledIR()); ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this)); + const char* ext_dir_env = std::getenv("GANDIVA_EXTENSION_DIR"); Review Comment: How about using `::arrow::internal::GetEnvVarNative()` instead? Can we use `${prefix}/lib/gandiva/extension/` or something as the default extension directory? ########## cpp/src/gandiva/engine.cc: ########## @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = + llvm::MemoryBuffer::getFile(entry.path().string()); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR file: ", + entry.path().string() + " Error: " + + buffer_or_error.getError().message())); + + std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(buffer_or_error.get()); + + llvm::Expected<std::unique_ptr<llvm::Module>> module_or_error = Review Comment: ```suggestion auto module_or_error = ``` ########## cpp/src/gandiva/engine.cc: ########## @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const 
std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = + llvm::MemoryBuffer::getFile(entry.path().string()); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR file: ", + entry.path().string() + " Error: " + + buffer_or_error.getError().message())); + + std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(buffer_or_error.get()); + + llvm::Expected<std::unique_ptr<llvm::Module>> module_or_error = + llvm::parseBitcodeFile(buffer->getMemBufferRef(), *context()); + if (!module_or_error) { + std::string str; + llvm::raw_string_ostream stream(str); + stream << module_or_error.takeError(); + return Status::CodeGenError("Failed to parse bitcode file: " + + entry.path().string() + " Error: " + stream.str()); + } + std::unique_ptr<llvm::Module> ir_module = std::move(module_or_error.get()); + + ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()), + Status::CodeGenError("verify of IR Module failed")); Review Comment: Can we add a detail message to this too like the above code? 
########## cpp/src/gandiva/engine.cc: ########## @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buffer_or_error = + llvm::MemoryBuffer::getFile(entry.path().string()); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR file: ", + entry.path().string() + " Error: " + + buffer_or_error.getError().message())); + + std::unique_ptr<llvm::MemoryBuffer> buffer = std::move(buffer_or_error.get()); Review Comment: ```suggestion auto buffer = std::move(buffer_or_error.get()); ``` ########## cpp/src/gandiva/extension_tests/multiple_registries/reg_1.json: ########## Review Comment: How about renaming them to `registry_1.json` and `registry_2.json`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> Review Comment: ```suggestion #include <filesystem> #include <fstream> #include <unordered_map> #include <vector> #include <arrow/type.h> #include <gandiva/function_registry_external.h> #include <rapidjson/document.h> ``` ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), Review Comment: ```suggestion return Status::Invalid("JSON parse error (offset ", doc.GetErrorOffset(), ``` ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); Review Comment: ```suggestion return Status::TypeError("Not a JSON object"); ``` ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); Review Comment: Is "pc" "pre compiled"? Can we use a more descriptive name for it, such as `pre_compiled_name` or `native_name`? Ah, Gandiva already uses "pc" like `InitPCMap`... Hmm... ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; Review Comment: Can we use `auto`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); Review Comment: `auto`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); 
+ int32_t flags = GetFlags(func); Review Comment: `auto`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + 
ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, Review Comment: `auto`? ########## cpp/src/gandiva/function_registry.cc: ########## @@ -45,6 +47,21 @@ std::vector<NativeFunction> FunctionRegistry::pc_registry_; SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap(); +std::vector<NativeFunction> LoadExternalFunctionRegistry() { + std::string ext_dir; + const char* ext_dir_env = std::getenv("GANDIVA_EXTENSION_DIR"); Review Comment: Can we unify the code that refers to `GANDIVA_EXTENSION_DIR`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if (!func["param_types"].IsArray()) { + return 
Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDecimalDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); + if (type_name == "decimal128") { + return arrow::decimal128(precision, scale); + } else if (type_name == "decimal256") { + return arrow::decimal256(precision, scale); + } + return arrow::decimal(precision, scale); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseListDataType( + 
const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("value_type") || !data_type["value_type"].IsObject()) { + return Status::TypeError( + "'value_type' property is required for list data type and should be an object"); Review Comment: Can we add `data_type` content to this error message? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); Review Comment: How about adding the real value to the error message? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if (!func["param_types"].IsArray()) { + return 
Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); Review Comment: `auto`? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static 
arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if (!func["param_types"].IsArray()) { + return 
Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDecimalDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); Review Comment: `auto`? ########## cpp/src/gandiva/function_registry_external_test.cc: ########## @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_external.h" +#include <gtest/gtest.h> +#include <filesystem> +#include "arrow/testing/gtest_util.h" +#include "gandiva/tests/test_util.h" + +namespace gandiva { +class TestExternalFunctionRegistry : public ::testing::Test { + public: + arrow::Result<std::vector<NativeFunction>> GetFuncs(const std::string& registry_dir) { + std::filesystem::path base(GANDIVA_EXTENSION_TEST_DIR); + return GetExternalFunctionRegistry((base / registry_dir).string()); + } +}; + +TEST_F(TestExternalFunctionRegistry, EmptyDir) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetExternalFunctionRegistry("")); + ASSERT_TRUE(funcs.empty()); +} + +TEST_F(TestExternalFunctionRegistry, FunctionWithoutName) { + auto funcs = GetFuncs("no_name_func_registry"); + ASSERT_TRUE(!funcs.ok()); Review Comment: Could you use `ASSERT_RAISES_WITH_MESSAGE()` (or `ASSERT_RAISES()`)? ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); 
+ int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". 
Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if (!func["param_types"].IsArray()) { + return Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = 
arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDecimalDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); + if (type_name == "decimal128") { + return arrow::decimal128(precision, scale); + } else if (type_name == "decimal256") { + return arrow::decimal256(precision, scale); + } + return arrow::decimal(precision, scale); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseListDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("value_type") || !data_type["value_type"].IsObject()) { + return Status::TypeError( + "'value_type' property is required for list data type and should be an object"); + } + ARROW_ASSIGN_OR_RAISE(auto value_type, ParseDataType(data_type["value_type"])); + return arrow::list(value_type); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseComplexDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + static const std::unordered_map< + std::string, std::function<arrow::Result<std::shared_ptr<arrow::DataType>>( + const rj::GenericValue<rj::UTF8<>>&)>> + complex_type_map = {{"timestamp", ParseTimestampDataType}, + {"decimal", 
ParseDecimalDataType}, + {"decimal128", ParseDecimalDataType}, + {"decimal256", ParseDecimalDataType}, + {"list", ParseListDataType}}; + const std::string type_name = data_type["type"].GetString(); + auto it = complex_type_map.find(type_name); + if (it == complex_type_map.end()) { + return Status::TypeError("Unsupported complex type name: ", type_name); + } + return it->second(data_type); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("type")) { + return Status::TypeError("'type' property is required for data type"); + } + auto type_name = data_type["type"].GetString(); + auto type = ParseDataTypeFromName(type_name); + if (type == nullptr) { + return ParseComplexDataType(data_type); + } else { + return type; + } + } + + static std::shared_ptr<arrow::DataType> ParseDataTypeFromName( + const std::string& type_name) { + static const std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> + simple_type_map = {{"null", arrow::null()}, + {"boolean", arrow::boolean()}, + {"uint8", arrow::uint8()}, + {"int8", arrow::int8()}, + {"uint16", arrow::uint16()}, + {"int16", arrow::int16()}, + {"uint32", arrow::uint32()}, + {"int32", arrow::int32()}, + {"uint64", arrow::uint64()}, + {"int64", arrow::int64()}, + {"float16", arrow::float16()}, + {"float32", arrow::float32()}, + {"float64", arrow::float64()}, + {"utf8", arrow::utf8()}, + {"large_utf8", arrow::large_utf8()}, + {"binary", arrow::binary()}, + {"large_binary", arrow::large_binary()}, + {"date32", arrow::date32()}, + {"date64", arrow::date64()}, + {"day_time_interval", arrow::day_time_interval()}, + {"month_interval", arrow::month_interval()}}; + + auto it = simple_type_map.find(type_name); + return it != simple_type_map.end() ? 
it->second : nullptr; + } +}; + +// iterate all files under registry_dir by file names +std::vector<std::filesystem::path> ListAllFiles(const std::string& registry_dir) { + if (registry_dir.empty()) { + return {}; + } + std::vector<std::filesystem::path> filenames; + for (const auto& entry : std::filesystem::directory_iterator(registry_dir)) { + filenames.push_back(entry.path()); + } + + std::sort(filenames.begin(), filenames.end()); + return filenames; +} + +arrow::Result<std::vector<NativeFunction>> GetExternalFunctionRegistry( + const std::string& registry_dir) { + std::vector<NativeFunction> registry; + auto filenames = ListAllFiles(registry_dir); + for (const auto& entry : filenames) { + if (entry.extension() == ".json") { + std::ifstream file(entry); + std::string content((std::istreambuf_iterator<char>(file)), + std::istreambuf_iterator<char>()); + + auto funcs_result = JsonRegistryParser::Parse(content); + if (!funcs_result.ok()) { + return funcs_result.status().WithMessage( + "Failed to parse json file: ", entry.string(), + ". Error: ", funcs_result.status().message()); + } + auto funcs = funcs_result.ValueUnsafe(); Review Comment: ```suggestion auto funcs = *funcs_result; ``` ########## cpp/src/gandiva/function_registry_external.cc: ########## @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <arrow/type.h> +#include <gandiva/function_registry_external.h> +#include <rapidjson/document.h> +#include <filesystem> +#include <fstream> +#include <unordered_map> +#include <vector> + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result<std::vector<NativeFunction>> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast<const rj::Document::Ch*>(json.data()), + static_cast<size_t>(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector<NativeFunction> funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result<std::string> GetString(const rj::GenericValue<rj::UTF8<>>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + 
" property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static arrow::Result<ResultNullableType> ParseResultNullable( + const rj::GenericValue<rj::UTF8<>>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue<rj::UTF8<>>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result<std::vector<std::string>> GetAliases( + const rj::GenericValue<rj::UTF8<>>& func) { + std::vector<std::string> aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result<arrow::DataTypeVector> ParseParamTypes( + const rj::GenericValue<rj::UTF8<>>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + 
return param_types; + } + if (!func["param_types"].IsArray()) { + return Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseTimestampDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDecimalDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); + if (type_name == "decimal128") { + return arrow::decimal128(precision, scale); + } else if (type_name == "decimal256") { + return arrow::decimal256(precision, scale); + } + return arrow::decimal(precision, scale); + } + + 
static arrow::Result<std::shared_ptr<arrow::DataType>> ParseListDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("value_type") || !data_type["value_type"].IsObject()) { + return Status::TypeError( + "'value_type' property is required for list data type and should be an object"); + } + ARROW_ASSIGN_OR_RAISE(auto value_type, ParseDataType(data_type["value_type"])); + return arrow::list(value_type); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseComplexDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + static const std::unordered_map< + std::string, std::function<arrow::Result<std::shared_ptr<arrow::DataType>>( + const rj::GenericValue<rj::UTF8<>>&)>> + complex_type_map = {{"timestamp", ParseTimestampDataType}, + {"decimal", ParseDecimalDataType}, + {"decimal128", ParseDecimalDataType}, + {"decimal256", ParseDecimalDataType}, + {"list", ParseListDataType}}; + const std::string type_name = data_type["type"].GetString(); + auto it = complex_type_map.find(type_name); + if (it == complex_type_map.end()) { + return Status::TypeError("Unsupported complex type name: ", type_name); + } + return it->second(data_type); + } + + static arrow::Result<std::shared_ptr<arrow::DataType>> ParseDataType( + const rj::GenericValue<rj::UTF8<>>& data_type) { + if (!data_type.HasMember("type")) { + return Status::TypeError("'type' property is required for data type"); + } + auto type_name = data_type["type"].GetString(); + auto type = ParseDataTypeFromName(type_name); + if (type == nullptr) { + return ParseComplexDataType(data_type); + } else { + return type; + } + } + + static std::shared_ptr<arrow::DataType> ParseDataTypeFromName( + const std::string& type_name) { + static const std::unordered_map<std::string, std::shared_ptr<arrow::DataType>> + simple_type_map = {{"null", arrow::null()}, + {"boolean", arrow::boolean()}, + {"uint8", arrow::uint8()}, + {"int8", arrow::int8()}, + {"uint16", arrow::uint16()}, + {"int16", 
arrow::int16()}, + {"uint32", arrow::uint32()}, + {"int32", arrow::int32()}, + {"uint64", arrow::uint64()}, + {"int64", arrow::int64()}, + {"float16", arrow::float16()}, + {"float32", arrow::float32()}, + {"float64", arrow::float64()}, + {"utf8", arrow::utf8()}, + {"large_utf8", arrow::large_utf8()}, + {"binary", arrow::binary()}, + {"large_binary", arrow::large_binary()}, + {"date32", arrow::date32()}, + {"date64", arrow::date64()}, + {"day_time_interval", arrow::day_time_interval()}, + {"month_interval", arrow::month_interval()}}; + + auto it = simple_type_map.find(type_name); + return it != simple_type_map.end() ? it->second : nullptr; + } +}; + +// iterate all files under registry_dir by file names +std::vector<std::filesystem::path> ListAllFiles(const std::string& registry_dir) { + if (registry_dir.empty()) { + return {}; + } + std::vector<std::filesystem::path> filenames; + for (const auto& entry : std::filesystem::directory_iterator(registry_dir)) { + filenames.push_back(entry.path()); + } + + std::sort(filenames.begin(), filenames.end()); + return filenames; +} + +arrow::Result<std::vector<NativeFunction>> GetExternalFunctionRegistry( + const std::string& registry_dir) { + std::vector<NativeFunction> registry; + auto filenames = ListAllFiles(registry_dir); + for (const auto& entry : filenames) { + if (entry.extension() == ".json") { Review Comment: Do we need `entry.is_regular_file()` check? ########## cpp/src/gandiva/function_registry_external.h: ########## @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <arrow/result.h> +#include <string> +#include <vector> +#include "gandiva/native_function.h" Review Comment: ```suggestion #include <string> #include <vector> #include <arrow/result.h> #include "gandiva/native_function.h" ``` ########## cpp/src/gandiva/precompiled/CMakeLists.txt: ########## @@ -144,5 +85,6 @@ if(ARROW_BUILD_TESTS) set_property(TEST gandiva-precompiled-test APPEND PROPERTY LABELS "unittest;gandiva-tests") + Review Comment: ```suggestion ``` ########## cpp/src/gandiva/function_registry_external_test.cc: ########## @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry_external.h" +#include <gtest/gtest.h> +#include <filesystem> +#include "arrow/testing/gtest_util.h" +#include "gandiva/tests/test_util.h" Review Comment: ```suggestion #include <filesystem> #include <gtest/gtest.h> #include "arrow/testing/gtest_util.h" #include "gandiva/function_registry_external.h" #include "gandiva/tests/test_util.h" ``` ########## cpp/src/gandiva/function_registry_external_test.cc: ########## @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gandiva/function_registry_external.h" +#include <gtest/gtest.h> +#include <filesystem> +#include "arrow/testing/gtest_util.h" +#include "gandiva/tests/test_util.h" + +namespace gandiva { +class TestExternalFunctionRegistry : public ::testing::Test { + public: + arrow::Result<std::vector<NativeFunction>> GetFuncs(const std::string& registry_dir) { + std::filesystem::path base(GANDIVA_EXTENSION_TEST_DIR); + return GetExternalFunctionRegistry((base / registry_dir).string()); + } +}; + +TEST_F(TestExternalFunctionRegistry, EmptyDir) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetExternalFunctionRegistry("")); + ASSERT_TRUE(funcs.empty()); +} + +TEST_F(TestExternalFunctionRegistry, FunctionWithoutName) { + auto funcs = GetFuncs("no_name_func_registry"); + ASSERT_TRUE(!funcs.ok()); +} + +TEST_F(TestExternalFunctionRegistry, DirWithJsonRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("simple_registry")); + ASSERT_EQ(funcs.size(), 1); + ASSERT_EQ(funcs[0].result_nullable_type(), ResultNullableType::kResultNullNever); + ASSERT_EQ(funcs[0].CanReturnErrors(), true); + ASSERT_EQ(funcs[0].pc_name(), "say_hello_utf8"); +} + +TEST_F(TestExternalFunctionRegistry, DirWithMultiJsonRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("multiple_registries")); + ASSERT_EQ(funcs.size(), 2); + auto sigs_0 = funcs[0].signatures(); + ASSERT_EQ(sigs_0.size(), 2); + ASSERT_EQ(sigs_0[0].param_types().size(), 1); + ASSERT_EQ(sigs_0[0].param_types()[0]->id(), arrow::Type::STRING); + ASSERT_EQ(sigs_0[0].ret_type()->id(), arrow::Type::INT64); Review Comment: Do we need to check `sigs_0[1]` too? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
