This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
commit a23fd1b8a8342489f8566466649ff26891a18fb4 Author: praveenbingo <[email protected]> AuthorDate: Mon Sep 24 11:51:40 2018 +0530 [Gandiva] Error handling support. Added support for error handling and messaging in IR functions. --- cpp/src/gandiva/compiled_expr.h | 3 +- cpp/src/gandiva/eval_batch.h | 6 ++ cpp/src/gandiva/execution_context.cc | 36 +++++++ cpp/src/gandiva/execution_context.h | 40 +++++++ cpp/src/gandiva/function_registry.cc | 82 +++++++------- cpp/src/gandiva/jni/id_to_module_map.h | 7 -- cpp/src/gandiva/like_holder_test.cc | 8 +- cpp/src/gandiva/llvm_generator.cc | 86 +++++++++------ cpp/src/gandiva/llvm_generator.h | 37 +++++++ cpp/src/gandiva/llvm_generator_test.cc | 4 +- cpp/src/gandiva/native_function.h | 14 +-- cpp/src/gandiva/precompiled/CMakeLists.txt | 6 +- cpp/src/gandiva/precompiled/arithmetic_ops.cc | 7 +- cpp/src/gandiva/precompiled/arithmetic_ops_test.cc | 14 +-- cpp/src/gandiva/precompiled/context_helper.h | 25 +++++ cpp/src/gandiva/precompiled/types.h | 2 +- cpp/src/gandiva/status.h | 10 ++ cpp/src/gandiva/tests/projector_test.cc | 28 +++-- cpp/src/gandiva/tests/to_string_test.cc | 10 +- .../arrow/gandiva/evaluator/ProjectorTest.java | 119 +++++++++++++++++++++ 20 files changed, 427 insertions(+), 117 deletions(-) diff --git a/cpp/src/gandiva/compiled_expr.h b/cpp/src/gandiva/compiled_expr.h index 142b311..db98842 100644 --- a/cpp/src/gandiva/compiled_expr.h +++ b/cpp/src/gandiva/compiled_expr.h @@ -23,7 +23,8 @@ namespace gandiva { -using EvalFunc = int (*)(uint8_t** buffers, uint8_t** local_bitmaps, int record_count); +using EvalFunc = int (*)(uint8_t **buffers, uint8_t **local_bitmaps, + int64_t execution_ctx_ptr, int record_count); /// \brief Tracks the compiled state for one expression. 
class CompiledExpr { diff --git a/cpp/src/gandiva/eval_batch.h b/cpp/src/gandiva/eval_batch.h index 977b455..c6501b9 100644 --- a/cpp/src/gandiva/eval_batch.h +++ b/cpp/src/gandiva/eval_batch.h @@ -25,6 +25,7 @@ #include "gandiva/arrow.h" #include "gandiva/gandiva_aliases.h" #include "gandiva/local_bitmaps_holder.h" +#include "gandiva/execution_context.h" namespace gandiva { @@ -38,6 +39,7 @@ class EvalBatch { buffers_array_.reset(new uint8_t*[num_buffers]); } local_bitmaps_holder_.reset(new LocalBitMapsHolder(num_records, num_local_bitmaps)); + execution_context_.reset(new ExecutionContext()); } int num_records() const { return num_records_; } @@ -69,6 +71,8 @@ class EvalBatch { return local_bitmaps_holder_->GetLocalBitMapArray(); } + ExecutionContext *GetExecutionContext() const { return execution_context_.get(); } + private: /// number of records in the current batch. int num_records_; @@ -82,6 +86,8 @@ class EvalBatch { std::unique_ptr<uint8_t*> buffers_array_; std::unique_ptr<LocalBitMapsHolder> local_bitmaps_holder_; + + std::unique_ptr<ExecutionContext> execution_context_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/execution_context.cc b/cpp/src/gandiva/execution_context.cc new file mode 100644 index 0000000..e0a2887 --- /dev/null +++ b/cpp/src/gandiva/execution_context.cc @@ -0,0 +1,36 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "codegen/execution_context.h" + +namespace gandiva { +#ifdef GDV_HELPERS +namespace helpers { +#endif + +void ExecutionContext::set_error_msg(const char *error_msg) { + if (error_msg_.empty()) { + error_msg_ = std::string(error_msg); + } +} + +std::string ExecutionContext::get_error() const { return error_msg_; } + +bool ExecutionContext::has_error() const { return !error_msg_.empty(); } + +#ifdef GDV_HELPERS +} +#endif + +} // namespace gandiva diff --git a/cpp/src/gandiva/execution_context.h b/cpp/src/gandiva/execution_context.h new file mode 100644 index 0000000..85a82f4 --- /dev/null +++ b/cpp/src/gandiva/execution_context.h @@ -0,0 +1,40 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ERROR_HOLDER_H +#define ERROR_HOLDER_H + +#include <string> + +namespace gandiva { +#ifdef GDV_HELPERS +namespace helpers { +#endif +/// Error holder for errors during llvm module execution +class ExecutionContext { + public: + std::string get_error() const; + + void set_error_msg(const char* error_msg); + + bool has_error() const; + + private: + std::string error_msg_; +}; +#ifdef GDV_HELPERS +} +#endif +} // namespace gandiva +#endif // ERROR_HOLDER_H diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index d265222..033d369 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -45,28 +45,23 @@ using std::vector; // - NULL handling is of type NULL_IF_NULL // // The pre-compiled fn name includes the base name & input type names. eg. add_int32_int32 -#define BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, TYPE(), true, \ - RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##TYPE##_##TYPE)) +#define BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, TYPE(), RESULT_NULL_IF_NULL, \ + STRINGIFY(NAME##_##TYPE##_##TYPE)) -// Binary functions that : -// - have the same input type for both params -// - output type is same as the input type -// - NULL handling is of type NULL_INTERNAL -// -// The pre-compiled fn name includes the base name & input type names. eg. 
-// divide_int64_int64 -#define BINARY_SYMMETRIC_NULL_INTERNAL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, TYPE(), true, \ - RESULT_NULL_INTERNAL, STRINGIFY(NAME##_##TYPE##_##TYPE)) +// Divide fubnction +#define DIVIDE(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, TYPE(), RESULT_NULL_INTERNAL, \ + STRINGIFY(NAME##_##TYPE##_##TYPE), false /* does not need holder */, \ + true /* can return error */) // Binary functions that : // - have different input types, or output type // - NULL handling is of type NULL_IF_NULL // // The pre-compiled fn name includes the base name & input type names. eg. mod_int64_int32 -#define BINARY_GENERIC_SAFE_NULL_IF_NULL(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \ - NativeFunction(#NAME, DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), true, \ +#define BINARY_GENERIC_SAFE_NULL_IF_NULL(NAME, IN_TYPE1, IN_TYPE2, OUT_TYPE) \ + NativeFunction(#NAME, DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), \ RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2)) // Binary functions that : @@ -76,24 +71,24 @@ using std::vector; // // The pre-compiled fn name includes the base name & input type names. // eg. equal_int32_int32 -#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), true, \ - RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##TYPE##_##TYPE)) +#define BINARY_RELATIONAL_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), RESULT_NULL_IF_NULL, \ + STRINGIFY(NAME##_##TYPE##_##TYPE)) // Unary functions that : // - NULL handling is of type NULL_IF_NULL // // The pre-compiled fn name includes the base name & input type name. eg. 
castFloat_int32 -#define UNARY_SAFE_NULL_IF_NULL(NAME, IN_TYPE, OUT_TYPE) \ - NativeFunction(#NAME, DataTypeVector{IN_TYPE()}, OUT_TYPE(), true, \ - RESULT_NULL_IF_NULL, STRINGIFY(NAME##_##IN_TYPE)) +#define UNARY_SAFE_NULL_IF_NULL(NAME, IN_TYPE, OUT_TYPE) \ + NativeFunction(#NAME, DataTypeVector{IN_TYPE()}, OUT_TYPE(), RESULT_NULL_IF_NULL, \ + STRINGIFY(NAME##_##IN_TYPE)) // Unary functions that : // - NULL handling is of type NULL_NEVER // // The pre-compiled fn name includes the base name & input type name. eg. isnull_int32 -#define UNARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE()}, boolean(), true, RESULT_NULL_NEVER, \ +#define UNARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, boolean(), RESULT_NULL_NEVER, \ STRINGIFY(NAME##_##TYPE)) // Binary functions that : @@ -101,49 +96,49 @@ using std::vector; // // The pre-compiled fn name includes the base name & input type names, // eg. is_distinct_from_int32_int32 -#define BINARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), true, \ - RESULT_NULL_NEVER, STRINGIFY(NAME##_##TYPE##_##TYPE)) +#define BINARY_SAFE_NULL_NEVER_BOOL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), TYPE()}, boolean(), RESULT_NULL_NEVER, \ + STRINGIFY(NAME##_##TYPE##_##TYPE)) // Extract functions (used with data/time types) that : // - NULL handling is of type NULL_IF_NULL // // The pre-compiled fn name includes the base name & input type name. eg. extractYear_date -#define EXTRACT_SAFE_NULL_IF_NULL(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), true, RESULT_NULL_IF_NULL, \ +#define EXTRACT_SAFE_NULL_IF_NULL(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), RESULT_NULL_IF_NULL, \ STRINGIFY(NAME##_##TYPE)) // Hash32 functions that : // - NULL handling is of type NULL_NEVER // // The pre-compiled fn name includes the base name & input type name. 
hash32_int8 -#define HASH32_SAFE_NULL_NEVER(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE()}, int32(), true, RESULT_NULL_NEVER, \ +#define HASH32_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int32(), RESULT_NULL_NEVER, \ STRINGIFY(NAME##_##TYPE)) // Hash32 functions that : // - NULL handling is of type NULL_NEVER // // The pre-compiled fn name includes the base name & input type name. hash32_int8 -#define HASH64_SAFE_NULL_NEVER(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), true, RESULT_NULL_NEVER, \ +#define HASH64_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE()}, int64(), RESULT_NULL_NEVER, \ STRINGIFY(NAME##_##TYPE)) // Hash32 functions with seed that : // - NULL handling is of type NULL_NEVER // // The pre-compiled fn name includes the base name & input type name. hash32WithSeed_int8 -#define HASH32_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), int32()}, int32(), true, \ - RESULT_NULL_NEVER, STRINGIFY(NAME##WithSeed_##TYPE)) +#define HASH32_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), int32()}, int32(), RESULT_NULL_NEVER, \ + STRINGIFY(NAME##WithSeed_##TYPE)) // Hash64 functions with seed that : // - NULL handling is of type NULL_NEVER // // The pre-compiled fn name includes the base name & input type name. 
hash32WithSeed_int8 -#define HASH64_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ - NativeFunction(#NAME, DataTypeVector{TYPE(), int64()}, int64(), true, \ - RESULT_NULL_NEVER, STRINGIFY(NAME##WithSeed_##TYPE)) +#define HASH64_SEED_SAFE_NULL_NEVER(NAME, TYPE) \ + NativeFunction(#NAME, DataTypeVector{TYPE(), int64()}, int64(), RESULT_NULL_NEVER, \ + STRINGIFY(NAME##WithSeed_##TYPE)) // Iterate the inner macro over all numeric types #define NUMERIC_TYPES(INNER, NAME) \ @@ -178,7 +173,7 @@ NativeFunction FunctionRegistry::pc_registry_[] = { NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, add), NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, subtract), NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, multiply), - NUMERIC_TYPES(BINARY_SYMMETRIC_NULL_INTERNAL, divide), + NUMERIC_TYPES(DIVIDE, divide), BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int32, int32), BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64), NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, equal), @@ -359,12 +354,13 @@ NativeFunction FunctionRegistry::pc_registry_[] = { BINARY_RELATIONAL_SAFE_NULL_IF_NULL(ends_with, utf8), BINARY_RELATIONAL_SAFE_NULL_IF_NULL(starts_with_plus_one, utf8), BINARY_RELATIONAL_SAFE_NULL_IF_NULL(ends_with_plus_one, utf8), - NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), true /*null_safe*/, - RESULT_NULL_IF_NULL, "like_utf8_utf8", true /*needs_holder*/), + + NativeFunction("like", DataTypeVector{utf8(), utf8()}, boolean(), RESULT_NULL_IF_NULL, + "like_utf8_utf8", true /*needs_holder*/), // Null internal (sample) - NativeFunction("half_or_null", DataTypeVector{int32()}, int32(), true /*null_safe*/, - RESULT_NULL_INTERNAL, "half_or_null_int32"), + NativeFunction("half_or_null", DataTypeVector{int32()}, int32(), RESULT_NULL_INTERNAL, + "half_or_null_int32"), }; // namespace gandiva FunctionRegistry::iterator FunctionRegistry::begin() const { diff --git a/cpp/src/gandiva/jni/id_to_module_map.h b/cpp/src/gandiva/jni/id_to_module_map.h index 
e722567..1f65887 100644 --- a/cpp/src/gandiva/jni/id_to_module_map.h +++ b/cpp/src/gandiva/jni/id_to_module_map.h @@ -46,13 +46,6 @@ class IdToModuleMap { HOLDER Lookup(jlong module_id) { HOLDER result = nullptr; - try { - result = map_.at(module_id); - } catch (const std::out_of_range&) { - } - if (result != nullptr) { - return result; - } mtx_.lock(); try { result = map_.at(module_id); diff --git a/cpp/src/gandiva/like_holder_test.cc b/cpp/src/gandiva/like_holder_test.cc index f3f5bae..97b384d 100644 --- a/cpp/src/gandiva/like_holder_test.cc +++ b/cpp/src/gandiva/like_holder_test.cc @@ -88,19 +88,19 @@ TEST_F(TestLikeHolder, TestOptimise) { // optimise for 'starts_with' auto fnode = LikeHolder::TryOptimize(BuildLike("xy 123z%")); EXPECT_EQ(fnode.descriptor()->name(), "starts_with"); - EXPECT_EQ(fnode.ToString(), "bool starts_with(utf8, (string) xy 123z)"); + EXPECT_EQ(fnode.ToString(), "bool starts_with((utf8) in, (const string) xy 123z)"); // optimise for 'ends_with' fnode = LikeHolder::TryOptimize(BuildLike("%xyz")); EXPECT_EQ(fnode.descriptor()->name(), "ends_with"); - EXPECT_EQ(fnode.ToString(), "bool ends_with(utf8, (string) xyz)"); + EXPECT_EQ(fnode.ToString(), "bool ends_with((utf8) in, (const string) xyz)"); // optimise for 'starts_with_plus_one fnode = LikeHolder::TryOptimize(BuildLike("xyz_")); - EXPECT_EQ(fnode.ToString(), "bool starts_with_plus_one(utf8, (string) xyz)"); + EXPECT_EQ(fnode.ToString(), "bool starts_with_plus_one((utf8) in, (const string) xyz)"); fnode = LikeHolder::TryOptimize(BuildLike("_xyz")); - EXPECT_EQ(fnode.ToString(), "bool ends_with_plus_one(utf8, (string) xyz)"); + EXPECT_EQ(fnode.ToString(), "bool ends_with_plus_one((utf8) in, (const string) xyz)"); // no optimisation for others. 
fnode = LikeHolder::TryOptimize(BuildLike("%xyz%")); diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 351864e..2312b10 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -74,7 +74,7 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out Status LLVMGenerator::Build(const ExpressionVector& exprs) { Status status; - for (auto& expr : exprs) { + for (auto &expr : exprs) { auto output = annotator_.AddOutputFieldDescriptor(expr->result()); status = Add(expr, output); GANDIVA_RETURN_NOT_OK(status); @@ -101,12 +101,15 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch, auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector); DCHECK_GT(eval_batch->GetNumBuffers(), 0); - for (auto& compiled_expr : compiled_exprs_) { + for (auto &compiled_expr : compiled_exprs_) { // generate data/offset vectors. EvalFunc jit_function = compiled_expr->jit_function(); jit_function(eval_batch->GetBufferArray(), eval_batch->GetLocalBitMapArray(), - static_cast<int>(record_batch.num_rows())); - + (int64_t)eval_batch->GetExecutionContext(), record_batch.num_rows()); + // check for execution errors + if (eval_batch->GetExecutionContext()->has_error()) { + return Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()); + } // generate validity vectors. 
ComputeBitMapsForExpr(*compiled_expr, *eval_batch); } @@ -166,7 +169,8 @@ llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, in // // The C-code equivalent is : // ------------------------------ -// int expr_0(int64_t *addrs, int64_t *local_bitmaps, int nrecords) { +// int expr_0(int64_t *addrs, int64_t *local_bitmaps, int 64_t execution_context_ptr, int +// nrecords) { // int *outVec = (int *) addrs[5]; // int *c0Vec = (int *) addrs[1]; // int *c1Vec = (int *) addrs[3]; @@ -181,8 +185,8 @@ llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, in // IR Code // -------- // -// define i32 @expr_0(i64* %args, i64* %local_bitmaps, i32 %nrecords) { -// entry: +// define i32 @expr_0(i64* %args, i64* %local_bitmaps, i64 %execution_context_ptr, , i32 +// %nrecords) { entry: // %outmemAddr = getelementptr i64, i64* %args, i32 5 // %outmem = load i64, i64* %outmemAddr // %outVec = inttoptr i64 %outmem to i32* @@ -214,10 +218,11 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out llvm::IRBuilder<>& builder = ir_builder(); // Create fn prototype : - // int expr_1 (long **addrs, long **bitmaps, int nrec) - std::vector<llvm::Type*> arguments; + // int expr_1 (long **addrs, long **bitmaps, long *context_ptr, int nrec) + std::vector<llvm::Type *> arguments; arguments.push_back(types_->i64_ptr_type()); arguments.push_back(types_->i64_ptr_type()); + arguments.push_back(types_->i64_type()); arguments.push_back(types_->i32_type()); llvm::FunctionType* prototype = llvm::FunctionType::get(types_->i32_type(), arguments, false /*isVarArg*/); @@ -237,9 +242,11 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out llvm::Value* arg_local_bitmaps = &*args; arg_local_bitmaps->setName("local_bitmaps"); ++args; - llvm::Value* arg_nrecords = &*args; - arg_nrecords->setName("nrecords"); + llvm::Value *arg_context_ptr = &*args; + arg_context_ptr->setName("context_ptr"); ++args; + 
llvm::Value *arg_nrecords = &*args; + arg_nrecords->setName("nrecords"); llvm::BasicBlock* loop_entry = llvm::BasicBlock::Create(context(), "entry", *fn); llvm::BasicBlock* loop_body = llvm::BasicBlock::Create(context(), "loop", *fn); @@ -257,7 +264,8 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, FieldDescriptorPtr out llvm::PHINode* loop_var = builder.CreatePHI(types_->i32_type(), 2, "loop_var"); // The visitor can add code to both the entry/loop blocks. - Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, loop_var); + Visitor visitor(this, *fn, loop_entry, arg_addrs, arg_local_bitmaps, arg_context_ptr, + loop_var); value_expr->Accept(visitor); LValuePtr output_value = visitor.result(); @@ -404,14 +412,16 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name, } // Visitor for generating the code for a decomposed expression. -LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function, - llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, - llvm::Value* arg_local_bitmaps, llvm::Value* loop_var) +LLVMGenerator::Visitor::Visitor(LLVMGenerator *generator, llvm::Function *function, + llvm::BasicBlock *entry_block, llvm::Value *arg_addrs, + llvm::Value *arg_local_bitmaps, + llvm::Value *arg_context_ptr, llvm::Value *loop_var) : generator_(generator), function_(function), entry_block_(entry_block), arg_addrs_(arg_addrs), arg_local_bitmaps_(arg_local_bitmaps), + arg_context_ptr_(arg_context_ptr), loop_var_(loop_var) { ADD_VISITOR_TRACE("Iteration %T", loop_var); } @@ -580,11 +590,13 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { dex.func_descriptor()->name()); LLVMTypes* types = generator_->types_.get(); + const NativeFunction *native_function = dex.native_function(); + // build the function params (ignore validity). 
- auto params = BuildParams(dex.function_holder().get(), dex.args(), false); + auto params = BuildParams(dex.function_holder().get(), dex.args(), false, + native_function->needs_context()); - const NativeFunction* native_function = dex.native_function(); - llvm::Type* ret_type = types->IRType(native_function->signature().ret_type()->id()); + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); llvm::Value* value = generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); @@ -595,12 +607,14 @@ void LLVMGenerator::Visitor::Visit(const NullableNeverFuncDex& dex) { ADD_VISITOR_TRACE("visit NullableNever base function " + dex.func_descriptor()->name()); LLVMTypes* types = generator_->types_.get(); + const NativeFunction *native_function = dex.native_function(); + // build function params along with validity. - auto params = BuildParams(dex.function_holder().get(), dex.args(), true); + auto params = BuildParams(dex.function_holder().get(), dex.args(), true, + native_function->needs_context()); - const NativeFunction* native_function = dex.native_function(); - llvm::Type* ret_type = types->IRType(native_function->signature().ret_type()->id()); - llvm::Value* value = + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); + llvm::Value *value = generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); result_.reset(new LValue(value)); } @@ -611,17 +625,19 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) { llvm::IRBuilder<>& builder = ir_builder(); LLVMTypes* types = generator_->types_.get(); + const NativeFunction *native_function = dex.native_function(); + // build function params along with validity. - auto params = BuildParams(dex.function_holder().get(), dex.args(), true); + auto params = BuildParams(dex.function_holder().get(), dex.args(), true, + native_function->needs_context()); // add an extra arg for validity (alloced on stack). 
llvm::AllocaInst* result_valid_ptr = new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_); params.push_back(result_valid_ptr); - const NativeFunction* native_function = dex.native_function(); - llvm::Type* ret_type = types->IRType(native_function->signature().ret_type()->id()); - llvm::Value* value = + llvm::Type *ret_type = types->IRType(native_function->signature().ret_type()->id()); + llvm::Value *value = generator_->AddFunctionCall(native_function->pc_name(), ret_type, params); // load the result validity and truncate to i1. @@ -856,10 +872,11 @@ LValuePtr LLVMGenerator::Visitor::BuildValueAndValidity(const ValueValidityPair& return std::make_shared<LValue>(value, length, validity); } -std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams( - FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity) { - LLVMTypes* types = generator_->types_.get(); - std::vector<llvm::Value*> params; +std::vector<llvm::Value *> LLVMGenerator::Visitor::BuildParams( + FunctionHolder *holder, const ValueValidityPairVector &args, bool with_validity, + bool with_context) { + LLVMTypes *types = generator_->types_.get(); + std::vector<llvm::Value *> params; // if the function has holder, add the holder pointer first. 
if (holder != nullptr) { @@ -887,6 +904,12 @@ std::vector<llvm::Value*> LLVMGenerator::Visitor::BuildParams( params.push_back(validity_expr); } } + + // add error holder if function can return error + if (with_context) { + params.push_back(arg_context_ptr_); + } + return params; } @@ -984,6 +1007,9 @@ std::string LLVMGenerator::ReplaceFormatInTrace(const std::string& in_msg, // float fmt = "%lf"; *print_fn = "print_double"; + } else if (type->isPointerTy()) { + // string + fmt = "%s"; } else { DCHECK(0); } diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index fb158b4..5d60617 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -30,6 +30,7 @@ #include "gandiva/annotator.h" #include "gandiva/compiled_expr.h" #include "gandiva/configuration.h" +#include "gandiva/execution_context.h" #include "gandiva/dex_visitor.h" #include "gandiva/engine.h" #include "gandiva/function_registry.h" @@ -73,6 +74,7 @@ class LLVMGenerator { /// Visitor to generate the code for a decomposed expression. 
class Visitor : public DexVisitor { public: +<<<<<<< HEAD Visitor(LLVMGenerator* generator, llvm::Function* function, llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, llvm::Value* arg_local_bitmaps, llvm::Value* loop_var); @@ -90,6 +92,26 @@ class LLVMGenerator { void Visit(const IfDex& dex) override; void Visit(const BooleanAndDex& dex) override; void Visit(const BooleanOrDex& dex) override; +======= + Visitor(LLVMGenerator *generator, llvm::Function *function, + llvm::BasicBlock *entry_block, llvm::Value *arg_addrs, + llvm::Value *arg_local_bitmaps, llvm::Value *arg_context_ptr, + llvm::Value *loop_var); + + void Visit(const VectorReadValidityDex &dex) override; + void Visit(const VectorReadFixedLenValueDex &dex) override; + void Visit(const VectorReadVarLenValueDex &dex) override; + void Visit(const LocalBitMapValidityDex &dex) override; + void Visit(const TrueDex &dex) override; + void Visit(const FalseDex &dex) override; + void Visit(const LiteralDex &dex) override; + void Visit(const NonNullableFuncDex &dex) override; + void Visit(const NullableNeverFuncDex &dex) override; + void Visit(const NullableInternalFuncDex &dex) override; + void Visit(const IfDex &dex) override; + void Visit(const BooleanAndDex &dex) override; + void Visit(const BooleanOrDex &dex) override; +>>>>>>> f50cf372... [Gandiva]Error handling support. LValuePtr result() { return result_; } @@ -107,9 +129,15 @@ class LLVMGenerator { LValuePtr BuildValueAndValidity(const ValueValidityPair& pair); // Generate code to build the params. +<<<<<<< HEAD std::vector<llvm::Value*> BuildParams(FunctionHolder* holder, const ValueValidityPairVector& args, bool with_validity); +======= + std::vector<llvm::Value *> BuildParams(FunctionHolder *holder, + const ValueValidityPairVector &args, + bool with_validity, bool with_context); +>>>>>>> f50cf372... [Gandiva]Error handling support. 
// Switch to the entry_block and get reference of the validity/value/offsets buffer llvm::Value* GetBufferReference(int idx, BufferType buffer_type, FieldPtr field); @@ -122,11 +150,20 @@ class LLVMGenerator { LLVMGenerator* generator_; LValuePtr result_; +<<<<<<< HEAD llvm::Function* function_; llvm::BasicBlock* entry_block_; llvm::Value* arg_addrs_; llvm::Value* arg_local_bitmaps_; llvm::Value* loop_var_; +======= + llvm::Function *function_; + llvm::BasicBlock *entry_block_; + llvm::Value *arg_addrs_; + llvm::Value *arg_local_bitmaps_; + llvm::Value *arg_context_ptr_; + llvm::Value *loop_var_; +>>>>>>> f50cf372... [Gandiva]Error handling support. }; // Generate the code for one expression, with the output of the expression going to diff --git a/cpp/src/gandiva/llvm_generator_test.cc b/cpp/src/gandiva/llvm_generator_test.cc index 5e9d99a..e50e8f8 100644 --- a/cpp/src/gandiva/llvm_generator_test.cc +++ b/cpp/src/gandiva/llvm_generator_test.cc @@ -112,7 +112,7 @@ TEST_F(TestLLVMGenerator, TestAdd) { reinterpret_cast<uint8_t*>(a1), reinterpret_cast<uint8_t*>(&in_bitmap), reinterpret_cast<uint8_t*>(out), reinterpret_cast<uint8_t*>(&out_bitmap), }; - eval_func(addrs, nullptr, num_records); + eval_func(addrs, nullptr, 0, num_records); uint32_t expected[] = {6, 8, 10, 12}; for (int i = 0; i < num_records; i++) { @@ -179,7 +179,7 @@ TEST_F(TestLLVMGenerator, TestNullInternal) { reinterpret_cast<uint8_t*>(&local_bitmap), }; - eval_func(addrs, local_bitmap_addrs, num_records); + eval_func(addrs, local_bitmap_addrs, 0, num_records); uint32_t expected_value[] = {0, 1, 0, 2}; bool expected_validity[] = {false, true, false, true}; diff --git a/cpp/src/gandiva/native_function.h b/cpp/src/gandiva/native_function.h index ee22629..744904f 100644 --- a/cpp/src/gandiva/native_function.h +++ b/cpp/src/gandiva/native_function.h @@ -42,25 +42,25 @@ class NativeFunction { const FunctionSignature& signature() const { return signature_; } std::string pc_name() const { return pc_name_; } 
ResultNullableType result_nullable_type() const { return result_nullable_type_; } - bool param_null_safe() const { return param_null_safe_; } bool needs_holder() const { return needs_holder_; } + bool needs_context() const { return needs_context_; } private: - NativeFunction(const std::string& base_name, const DataTypeVector& param_types, - DataTypePtr ret_type, bool param_null_safe, - const ResultNullableType& result_nullable_type, - const std::string& pc_name, bool needs_holder = false) + NativeFunction(const std::string &base_name, const DataTypeVector &param_types, + DataTypePtr ret_type, const ResultNullableType &result_nullable_type, + const std::string &pc_name, bool needs_holder = false, + bool needs_context = false) : signature_(base_name, param_types, ret_type), - param_null_safe_(param_null_safe), needs_holder_(needs_holder), + needs_context_(needs_context), result_nullable_type_(result_nullable_type), pc_name_(pc_name) {} FunctionSignature signature_; /// attributes - bool param_null_safe_; bool needs_holder_; + bool needs_context_; ResultNullableType result_nullable_type_; /// pre-compiled function name. 
diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt b/cpp/src/gandiva/precompiled/CMakeLists.txt index 4857314..e82b4d4 100644 --- a/cpp/src/gandiva/precompiled/CMakeLists.txt +++ b/cpp/src/gandiva/precompiled/CMakeLists.txt @@ -34,8 +34,8 @@ foreach(SRC_FILE ${PRECOMPILED_SRCS}) set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc) add_custom_command( OUTPUT ${BC_FILE} - COMMAND ${CLANG_EXECUTABLE} -I${CMAKE_SOURCE_DIR}/src - -std=c++11 -emit-llvm -O2 -c ${ABSOLUTE_SRC} -o ${BC_FILE} + COMMAND ${CLANG_EXECUTABLE} + -D GDV_HELPERS -std=c++11 -emit-llvm -O2 -c ${ABSOLUTE_SRC} -o ${BC_FILE} DEPENDS ${SRC_FILE}) list(APPEND BC_FILES ${BC_FILE}) endforeach() @@ -57,4 +57,4 @@ add_precompiled_unit_test(time_test.cc time.cc timestamp_arithmetic.cc) add_precompiled_unit_test(hash_test.cc hash.cc) add_precompiled_unit_test(sample_test.cc sample.cc) add_precompiled_unit_test(string_ops_test.cc string_ops.cc) -add_precompiled_unit_test(arithmetic_ops_test.cc arithmetic_ops.cc) +add_precompiled_unit_test(arithmetic_ops_test.cc arithmetic_ops.cc ../codegen/execution_context.cc) diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops.cc b/cpp/src/gandiva/precompiled/arithmetic_ops.cc index 7ad1e35..aa337eb 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops.cc @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. +#include "../codegen/execution_context.h" + extern "C" { +#include "./context_helper.h" #include "./types.h" // Expand inner macro for all numeric types. 
@@ -161,12 +164,14 @@ NUMERIC_BOOL_DATE_FUNCTION(IS_NOT_DISTINCT_FROM) #define DIVIDE_NULL_INTERNAL(TYPE) \ FORCE_INLINE \ TYPE divide_##TYPE##_##TYPE(TYPE in1, boolean is_valid1, TYPE in2, boolean is_valid2, \ - bool *out_valid) { \ + int64 execution_context, bool* out_valid) { \ *out_valid = false; \ if (!is_valid1 || !is_valid2) { \ return 0; \ } \ if (in2 == 0) { \ + char const* err_msg = "divide by zero error"; \ + set_error_msg(execution_context, err_msg); \ return 0; \ } \ *out_valid = true; \ diff --git a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc index 4c4573d..d7fded8 100644 --- a/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc +++ b/cpp/src/gandiva/precompiled/arithmetic_ops_test.cc @@ -17,6 +17,7 @@ #include <gtest/gtest.h> #include "gandiva/precompiled/types.h" +#include "../execution_context.h" namespace gandiva { @@ -36,17 +37,18 @@ TEST(TestArithmeticOps, TestMod) { EXPECT_EQ(mod_int64_int32(10, 0), 10); } TEST(TestArithmeticOps, TestDivide) { boolean is_valid; - int64 out = divide_int64_int64(10, true, 0, true, &is_valid); + gandiva::helpers::ExecutionContext error_holder; + int64 out = divide_int64_int64(10, true, 0, true, (int64)&error_holder, &is_valid); EXPECT_EQ(out, 0); EXPECT_EQ(is_valid, false); + EXPECT_EQ(error_holder.has_error(), true); + EXPECT_EQ(error_holder.get_error(), "divide by zero error"); - out = divide_int64_int64(10, true, 2, false, &is_valid); - EXPECT_EQ(out, 0); - EXPECT_EQ(is_valid, false); - - out = divide_int64_int64(10, true, 2, true, &is_valid); + gandiva::helpers::ExecutionContext error_holder1; + out = divide_int64_int64(10, true, 2, true, (int64)&error_holder, &is_valid); EXPECT_EQ(out, 5); EXPECT_EQ(is_valid, true); + EXPECT_EQ(error_holder1.has_error(), false); } } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/context_helper.h b/cpp/src/gandiva/precompiled/context_helper.h new file mode 100644 index 0000000..62bae62 --- /dev/null +++ 
b/cpp/src/gandiva/precompiled/context_helper.h @@ -0,0 +1,25 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_CONTEXT_HELPER_H +#define GANDIVA_CONTEXT_HELPER_H + +#include "../codegen/execution_context.h" + +void set_error_msg(int64_t context_ptr, char const* err_msg) { + gandiva::helpers::ExecutionContext* execution_context_ptr = + reinterpret_cast<gandiva::helpers::ExecutionContext*>(context_ptr); + (execution_context_ptr)->set_error_msg(err_msg); +} +#endif diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index ab7e9c6..168f93a 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -123,7 +123,7 @@ int32 mem_compare(const char* left, int32 left_len, const char* right, int32 rig int32 mod_int64_int32(int64 left, int32 right); int64 divide_int64_int64(int64 in1, boolean is_valid1, int64 in2, boolean is_valid2, - bool *out_valid); + int64 error_holder, bool *out_valid); bool starts_with_utf8_utf8(const char *data, int32 data_len, const char *prefix, int32 prefix_len); diff --git a/cpp/src/gandiva/status.h b/cpp/src/gandiva/status.h index dde0f93..68e1230 100644 --- a/cpp/src/gandiva/status.h +++ b/cpp/src/gandiva/status.h @@ -63,6 +63,7 @@ enum class StatusCode : char { CodeGenError = 2, ArrowError = 3, ExpressionValidationError = 4, + ExecutionError = 5, }; class Status { @@ -107,6 +108,10 @@ class Status { 
return Status(StatusCode::ExpressionValidationError, msg); } + static Status ExecutionError(const std::string& msg) { + return Status(StatusCode::ExecutionError, msg); + } + // Returns true if the status indicates success. bool ok() const { return (state_ == NULL); } @@ -120,6 +125,8 @@ class Status { return code() == StatusCode::ExpressionValidationError; } + bool IsExecutionError() const { return code() == StatusCode::ExecutionError; } + // Return a string representation of this status suitable for printing. // Returns the string "OK" for success. std::string ToString() const; @@ -234,6 +241,9 @@ inline std::string Status::CodeAsString() const { case StatusCode::ExpressionValidationError: type = "ExpressionValidationError"; break; + case StatusCode::ExecutionError: + type = "ExecutionError"; + break; default: type = "Unknown"; break; diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 743bab2..cda5b75 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -583,11 +583,9 @@ TEST_F(TestProjector, TestDivideZero) { EXPECT_TRUE(status.ok()) << status.message(); // Create a row-batch with some sample data - int num_records = 4; - auto array0 = MakeArrowArrayInt32({2, 3, 4, 5}, {true, true, true, true}); - auto array1 = MakeArrowArrayInt32({1, 2, 2, 0}, {true, true, false, true}); - // expected output - auto exp_div = MakeArrowArrayInt32({2, 1, 0, 0}, {true, true, false, false}); + int num_records = 5; + auto array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true}); + auto array1 = MakeArrowArrayInt32({1, 2, 2, 0, 0}, {true, true, false, true, true}); // prepare input record batch auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); @@ -595,10 +593,26 @@ TEST_F(TestProjector, TestDivideZero) { // Evaluate expression arrow::ArrayVector outputs; status = projector->Evaluate(*in_batch, pool_, &outputs); - EXPECT_TRUE(status.ok()) 
<< status.message(); + EXPECT_EQ(status.code(), StatusCode::ExecutionError); + std::string expected_error = "divide by zero error"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); + + // Testing for second batch that has no error should succeed. + num_records = 5; + array0 = MakeArrowArrayInt32({2, 3, 4, 5, 6}, {true, true, true, true, true}); + array1 = MakeArrowArrayInt32({1, 2, 2, 1, 1}, {true, true, false, true, true}); + + // prepare input record batch + in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + // expected output + auto exp = MakeArrowArrayInt32({2, 1, 2, 5, 6}, {true, true, false, true, true}); + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); // Validate results - EXPECT_ARROW_ARRAY_EQUALS(exp_div, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); } TEST_F(TestProjector, TestModZero) { diff --git a/cpp/src/gandiva/tests/to_string_test.cc b/cpp/src/gandiva/tests/to_string_test.cc index 38c2783..97919d4 100644 --- a/cpp/src/gandiva/tests/to_string_test.cc +++ b/cpp/src/gandiva/tests/to_string_test.cc @@ -44,12 +44,12 @@ TEST_F(TestToString, TestAll) { auto literal_node = TreeExprBuilder::MakeLiteral((uint64_t)100); auto literal_expr = TreeExprBuilder::MakeExpression(literal_node, arrow::field("r", int64())); - CHECK_EXPR_TO_STRING(literal_expr, "(uint64) 100"); + CHECK_EXPR_TO_STRING(literal_expr, "(const uint64) 100"); auto f0 = arrow::field("f0", float64()); auto f0_node = TreeExprBuilder::MakeField(f0); auto f0_expr = TreeExprBuilder::MakeExpression(f0_node, f0); - CHECK_EXPR_TO_STRING(f0_expr, "double"); + CHECK_EXPR_TO_STRING(f0_expr, "(double) f0"); auto f1 = arrow::field("f1", int64()); auto f2 = arrow::field("f2", int64()); @@ -57,7 +57,7 @@ TEST_F(TestToString, TestAll) { auto f2_node = TreeExprBuilder::MakeField(f2); auto add_node = TreeExprBuilder::MakeFunction("add", {f1_node, f2_node}, int64()); auto 
add_expr = TreeExprBuilder::MakeExpression(add_node, f1); - CHECK_EXPR_TO_STRING(add_expr, "int64 add(int64, int64)"); + CHECK_EXPR_TO_STRING(add_expr, "int64 add((int64) f1, (int64) f2)"); auto cond_node = TreeExprBuilder::MakeFunction( "lesser_than", {f0_node, TreeExprBuilder::MakeLiteral(static_cast<float>(0))}, @@ -69,7 +69,7 @@ TEST_F(TestToString, TestAll) { auto if_expr = TreeExprBuilder::MakeExpression(if_node, f1); CHECK_EXPR_TO_STRING( if_expr, - "if (bool lesser_than(double, (float) 0 raw(0))) { int64 } else { int64 }"); + "if (bool lesser_than((double) f0, (const float) 0 raw(0))) { (int64) f1 } else { (int64) f2 }"); auto f1_gt_100 = TreeExprBuilder::MakeFunction("greater_than", {f1_node, literal_node}, boolean()); @@ -80,7 +80,7 @@ TEST_F(TestToString, TestAll) { TreeExprBuilder::MakeExpression(and_node, arrow::field("f0", boolean())); CHECK_EXPR_TO_STRING( and_expr, - "bool greater_than(int64, (uint64) 100) && bool equals(int64, (uint64) 100)"); + "bool greater_than((int64) f1, (const uint64) 100) && bool equals((int64) f2, (const uint64) 100)"); } } // namespace gandiva diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 5892b73..2c18b75 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -46,6 +46,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; import static org.junit.Assert.assertEquals; @@ -248,6 +249,124 @@ public class ProjectorTest extends BaseEvaluatorTest { } @Test + public void testEvaluateDivZero() throws GandivaException, Exception { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", 
int32); + List<Field> args = Lists.newArrayList(a, b); + + Field retType = Field.nullable("c", int32); + ExpressionTree root = TreeBuilder.makeExpression("divide", args, retType); + + List<ExpressionTree> exprs = Lists.newArrayList(root); + + Schema schema = new Schema(args); + Projector eval = Projector.make(schema, exprs); + + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + // second half is "undefined" + int[] values_a = new int[]{2, 2}; + int[] values_b = new int[]{1, 0}; + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(values_a); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(values_b); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + boolean exceptionThrown = false; + try { + eval.evaluate(batch, output); + } catch (GandivaException e) { + Assert.assertTrue(e.getMessage().contains("divide by zero")); + exceptionThrown = true; + } + Assert.assertTrue(exceptionThrown); + + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + eval.close(); + } + + @Test + public void testDivZeroParallel() throws GandivaException, InterruptedException { + Field a = Field.nullable("a", int32); + Field b = Field.nullable("b", int32); + Field c = Field.nullable("c", int32); + List<Field> cols = Lists.newArrayList(a, b); + Schema s = new Schema(cols); + + List<Field> args = Lists.newArrayList(a, b); + + + ExpressionTree expr = TreeBuilder.makeExpression("divide", args, c); + List<ExpressionTree> exprs = Lists.newArrayList(expr); + + ExecutorService executors = Executors.newFixedThreadPool(16); + + AtomicInteger errorCount = new AtomicInteger(0); + AtomicInteger 
errorCountExp = new AtomicInteger(0); + // pre-build the projector so that same projector is used for all executions. + Projector.make(s, exprs); + + IntStream.range(0, 1000).forEach(i -> { + executors.submit(() -> { + try { + Projector evaluator = Projector.make(s, exprs); + int numRows = 2; + byte[] validity = new byte[]{(byte) 255}; + int[] values_a = new int[]{2, 2}; + int[] values_b; + if (i%2 == 0) { + errorCountExp.incrementAndGet(); + values_b = new int[]{1, 0}; + } else { + values_b = new int[]{1, 1}; + } + + ArrowBuf validitya = buf(validity); + ArrowBuf valuesa = intBuf(values_a); + ArrowBuf validityb = buf(validity); + ArrowBuf valuesb = intBuf(values_b); + ArrowRecordBatch batch = new ArrowRecordBatch( + numRows, + Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, + 0)), + Lists.newArrayList(validitya, valuesa, validityb, valuesb)); + + IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator); + intVector.allocateNew(numRows); + + List<ValueVector> output = new ArrayList<ValueVector>(); + output.add(intVector); + try { + evaluator.evaluate(batch, output); + } catch (GandivaException e) { + errorCount.incrementAndGet(); + } + // free buffers + releaseRecordBatch(batch); + releaseValueVectors(output); + evaluator.close(); + } catch (GandivaException e) { + } + }); + }); + executors.shutdown(); + executors.awaitTermination(100, java.util.concurrent.TimeUnit.SECONDS); + Assert.assertEquals(errorCountExp.intValue(), errorCount.intValue()); + } + + @Test public void testAdd3() throws GandivaException, Exception { Field x = Field.nullable("x", int32); Field N2x = Field.nullable("N2x", int32);
