IMPALA-1430,IMPALA-4878,IMPALA-4879: codegen native UDAs. This change reuses the existing infrastructure for codegen'ing builtin UDAs and for codegen'ing calls to UDFs. GetUdf() is refactored into LlvmCodeGen::LoadFunction() so that it supports both UDFs and UDAs.
IR UDAs are still not allowed by the frontend. It's unclear if we want to enable them going forward because of the difficulties in testing and supporting IR UDFs/UDAs. This also fixes some bugs with the Get*Type() methods of FunctionContext. GetArgType() and related methods now always return the logical input types of the aggregate function. Getting the tests to pass required fixing IMPALA-4878 because they called GetIntermediateType(). Testing: test_udfs.py tests UDAs with codegen enabled and disabled. Added some asserts to test UDAs to check that the correct types are passed in via the FunctionContext. Change-Id: Id1708eaa96eb76fb9bec5eeabf209f81c88eec2f Reviewed-on: http://gerrit.cloudera.org:8080/5161 Reviewed-by: Dan Hecht <[email protected]> Tested-by: Impala Public Jenkins Project: http://git-wip-us.apache.org/repos/asf/incubator-impala/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-impala/commit/d2d3f4c1 Tree: http://git-wip-us.apache.org/repos/asf/incubator-impala/tree/d2d3f4c1 Diff: http://git-wip-us.apache.org/repos/asf/incubator-impala/diff/d2d3f4c1 Branch: refs/heads/master Commit: d2d3f4c1a6eefb3f1335da8bd6791fbebd63d98b Parents: 1335af3 Author: Tim Armstrong <[email protected]> Authored: Fri Nov 18 13:34:15 2016 -0800 Committer: Impala Public Jenkins <[email protected]> Committed: Fri Feb 10 02:18:32 2017 +0000 ---------------------------------------------------------------------- be/src/codegen/codegen-anyval.cc | 4 +- be/src/codegen/codegen-anyval.h | 4 +- be/src/codegen/llvm-codegen.cc | 113 +++++++++++++- be/src/codegen/llvm-codegen.h | 25 ++- be/src/exec/partitioned-aggregation-node.cc | 26 +--- be/src/exprs/agg-fn-evaluator.cc | 64 +++++--- be/src/exprs/agg-fn-evaluator.h | 21 +-- be/src/exprs/anyval-util.cc | 34 ++-- be/src/exprs/anyval-util.h | 2 + be/src/exprs/expr.cc | 8 +- be/src/exprs/expr.h | 9 +- be/src/exprs/scalar-fn-call.cc | 154 +++---------------- be/src/exprs/scalar-fn-call.h | 5 +- 
be/src/exprs/timestamp-functions.cc | 4 +- be/src/runtime/types.cc | 6 + be/src/runtime/types.h | 2 + be/src/testutil/test-udas.cc | 132 ++++++++++++++-- be/src/testutil/test-udas.h | 1 + be/src/udf/udf-internal.h | 6 + be/src/udf/udf-ir.cc | 4 + be/src/udf/udf.h | 18 ++- common/thrift/Exprs.thrift | 5 + .../apache/impala/analysis/AggregateInfo.java | 10 +- .../impala/analysis/FunctionCallExpr.java | 61 ++++++-- .../functional-query/queries/QueryTest/uda.test | 19 +++ tests/query_test/test_udfs.py | 10 ++ 26 files changed, 500 insertions(+), 247 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/codegen-anyval.cc ---------------------------------------------------------------------- diff --git a/be/src/codegen/codegen-anyval.cc b/be/src/codegen/codegen-anyval.cc index bd27409..a778812 100644 --- a/be/src/codegen/codegen-anyval.cc +++ b/be/src/codegen/codegen-anyval.cc @@ -67,7 +67,7 @@ Type* CodegenAnyVal::GetLoweredType(LlvmCodeGen* cg, const ColumnType& type) { } } -Type* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { +PointerType* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { return GetLoweredType(cg, type)->getPointerTo(); } @@ -116,7 +116,7 @@ Type* CodegenAnyVal::GetUnloweredType(LlvmCodeGen* cg, const ColumnType& type) { return result; } -Type* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { +PointerType* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type) { return GetUnloweredType(cg, type)->getPointerTo(); } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/codegen-anyval.h ---------------------------------------------------------------------- diff --git a/be/src/codegen/codegen-anyval.h b/be/src/codegen/codegen-anyval.h index c07f3eb..13494ac 100644 --- a/be/src/codegen/codegen-anyval.h +++ 
b/be/src/codegen/codegen-anyval.h @@ -95,7 +95,7 @@ class CodegenAnyVal { /// Returns the lowered AnyVal pointer type associated with 'type'. /// E.g.: TYPE_BOOLEAN => i16* - static llvm::Type* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type); + static llvm::PointerType* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& type); /// Returns the unlowered AnyVal type associated with 'type'. /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal" @@ -103,7 +103,7 @@ class CodegenAnyVal { /// Returns the unlowered AnyVal pointer type associated with 'type'. /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"* - static llvm::Type* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type); + static llvm::PointerType* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& type); /// Return the constant type-lowered value corresponding to a null *Val. /// E.g.: passing TYPE_DOUBLE (corresponding to the lowered DoubleVal { i8, double }) http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/llvm-codegen.cc ---------------------------------------------------------------------- diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc index 3d730d5..6cf43f6 100644 --- a/be/src/codegen/llvm-codegen.cc +++ b/be/src/codegen/llvm-codegen.cc @@ -70,6 +70,7 @@ #include "util/hdfs-util.h" #include "util/path-builder.h" #include "util/runtime-profile-counters.h" +#include "util/symbols-util.h" #include "util/test-info.h" #include "common/names.h" @@ -766,8 +767,116 @@ Function* LlvmCodeGen::FnPrototype::GeneratePrototype( return fn; } -int LlvmCodeGen::ReplaceCallSites(Function* caller, Function* new_fn, - const string& target_name) { +Status LlvmCodeGen::LoadFunction(const TFunction& fn, const std::string& symbol, + const ColumnType* return_type, const std::vector<ColumnType>& arg_types, + int num_fixed_args, bool has_varargs, Function** llvm_fn, + LibCacheEntry** cache_entry) { + DCHECK_GE(arg_types.size(), 
num_fixed_args); + DCHECK(has_varargs || arg_types.size() == num_fixed_args); + DCHECK(!has_varargs || arg_types.size() > num_fixed_args); + // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd. + // TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd + // code. Always use the interpreted version of these functions. + // TODO: fix these built-in functions so we don't need 'broken_builtin' below. + bool broken_builtin = fn.name.function_name == "from_utc_timestamp" + || fn.name.function_name == "to_utc_timestamp" + || symbol.find("AddSub") != string::npos; + if (fn.binary_type == TFunctionBinaryType::NATIVE + || (fn.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) { + // In this path, we are calling a precompiled native function, either a UDF + // in a .so or a builtin using the UDF interface. + void* fn_ptr; + Status status = LibCache::instance()->GetSoFunctionPtr( + fn.hdfs_location, symbol, &fn_ptr, cache_entry); + if (!status.ok() && fn.binary_type == TFunctionBinaryType::BUILTIN) { + // Builtins symbols should exist unless there is a version mismatch. + status.AddDetail( + ErrorMsg(TErrorCode::MISSING_BUILTIN, fn.name.function_name, symbol).msg()); + } + RETURN_IF_ERROR(status); + DCHECK(fn_ptr != NULL); + + // Per the x64 ABI, DecimalVals are returned via a DecimalVal* output argument. + // So, the return type is void. + bool is_decimal = return_type != NULL && return_type->type == TYPE_DECIMAL; + Type* llvm_return_type = return_type == NULL || is_decimal ? + void_type() : + CodegenAnyVal::GetLoweredType(this, *return_type); + + // Convert UDF function pointer to Function*. Start by creating a function + // prototype for it. 
+ FnPrototype prototype(this, symbol, llvm_return_type); + + if (is_decimal) { + // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument + Type* output_type = CodegenAnyVal::GetUnloweredPtrType(this, *return_type); + prototype.AddArgument("output", output_type); + } + + // The "FunctionContext*" argument. + prototype.AddArgument("ctx", GetPtrType("class.impala_udf::FunctionContext")); + + // The "fixed" arguments for the UDF function, followed by the variable arguments, + // if any. + for (int i = 0; i < num_fixed_args; ++i) { + Type* arg_type = CodegenAnyVal::GetUnloweredPtrType(this, arg_types[i]); + prototype.AddArgument(Substitute("fixed_arg_$0", i), arg_type); + } + + if (has_varargs) { + prototype.AddArgument("num_var_arg", GetType(TYPE_INT)); + // Get the vararg type from the first vararg. + prototype.AddArgument( + "var_arg", CodegenAnyVal::GetUnloweredPtrType(this, arg_types[num_fixed_args])); + } + + // Create a Function* with the generated type. This is only a function + // declaration, not a definition, since we do not create any basic blocks or + // instructions in it. + *llvm_fn = prototype.GeneratePrototype(NULL, NULL, false); + + // Associate the dynamically loaded function pointer with the Function* we defined. + // This tells LLVM where the compiled function definition is located in memory. + execution_engine_->addGlobalMapping(*llvm_fn, fn_ptr); + } else if (fn.binary_type == TFunctionBinaryType::BUILTIN) { + // In this path, we're running a builtin with the UDF interface. The IR is + // in the llvm module. Builtin functions may use Expr::GetConstant(). Clone the + // function so that we can replace constants in the copied function. + *llvm_fn = GetFunction(symbol, true); + if (*llvm_fn == NULL) { + // Builtins symbols should exist unless there is a version mismatch. + return Status(Substitute("Builtin '$0' with symbol '$1' does not exist. 
Verify " + "that all your impalads are the same version.", + fn.name.function_name, symbol)); + } + // Rename the function to something more readable than the mangled name. + string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str()); + (*llvm_fn)->setName(demangled_name); + } else { + // We're running an IR UDF. + DCHECK_EQ(fn.binary_type, TFunctionBinaryType::IR); + + string local_path; + RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath( + fn.hdfs_location, LibCache::TYPE_IR, &local_path)); + // Link the UDF module into this query's main module so the UDF's functions are + // available in the main module. + RETURN_IF_ERROR(LinkModule(local_path)); + + *llvm_fn = GetFunction(symbol, true); + if (*llvm_fn == NULL) { + return Status(Substitute("Unable to load function '$0' from LLVM module '$1'", + symbol, fn.hdfs_location)); + } + // Rename the function to something more readable than the mangled name. + string demangled_name = SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str()); + (*llvm_fn)->setName(demangled_name); + } + return Status::OK(); +} + +int LlvmCodeGen::ReplaceCallSites( + Function* caller, Function* new_fn, const string& target_name) { DCHECK(!is_compiled_); DCHECK(caller->getParent() == module_); DCHECK(caller != NULL); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/llvm-codegen.h ---------------------------------------------------------------------- diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h index 9615664..d51faab 100644 --- a/be/src/codegen/llvm-codegen.h +++ b/be/src/codegen/llvm-codegen.h @@ -294,6 +294,23 @@ class LlvmCodeGen { /// functions. Status FinalizeModule(); + /// Loads a native or IR function 'fn' with symbol 'symbol' from the builtins or + /// an external library and puts the result in *llvm_fn. *llvm_fn can be safely + /// modified in place, because it is either newly generated or cloned. 
The caller must + /// call FinalizeFunction() on 'llvm_fn' once it is done modifying it. The function has + /// return type 'return_type' (void if 'return_type' is NULL) and input argument types + /// 'arg_types'. The first 'num_fixed_args' arguments are fixed arguments, and the + /// remaining arguments are varargs. 'has_varargs' indicates whether the function + /// accepts varargs. If 'has_varargs' is true, there must be at least one vararg. If + /// the function is loaded from a library, 'cache_entry' is updated to point to the + /// library containing the function. If 'cache_entry' is set to a non-NULL value by + /// this function, the caller must call LibCache::DecrementUseCount() on it when done + /// using the function. + Status LoadFunction(const TFunction& fn, const std::string& symbol, + const ColumnType* return_type, const std::vector<ColumnType>& arg_types, + int num_fixed_args, bool has_varargs, llvm::Function** llvm_fn, + LibCacheEntry** cache_entry); + /// Replaces all instructions in 'caller' that call 'target_name' with a call /// instruction to 'new_fn'. Returns the number of call sites updated. /// @@ -485,10 +502,6 @@ class LlvmCodeGen { llvm::Value* CodegenArrayAt( LlvmBuilder*, llvm::Value* array, int idx, const char* name = ""); - /// Loads a module at 'file' and links it to the module associated with - /// this LlvmCodeGen object. The module must be on the local filesystem. - Status LinkModule(const std::string& file); - /// If there are more than this number of expr trees (or functions that evaluate /// expressions), avoid inlining avoid inlining for the exprs exceeding this threshold. static const int CODEGEN_INLINE_EXPRS_THRESHOLD = 100; @@ -538,6 +551,10 @@ class LlvmCodeGen { Status LoadModuleFromMemory(std::unique_ptr<llvm::MemoryBuffer> module_ir_buf, std::string module_name, std::unique_ptr<llvm::Module>* module); + /// Loads a module at 'file' and links it to the module associated with + /// this LlvmCodeGen object. 
The module must be on the local filesystem. + Status LinkModule(const std::string& file); + /// Strip global constructors and destructors from an LLVM module. We never run them /// anyway (they must be explicitly invoked) so it is dead code. static void StripGlobalCtorsDtors(llvm::Module* module); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exec/partitioned-aggregation-node.cc ---------------------------------------------------------------------- diff --git a/be/src/exec/partitioned-aggregation-node.cc b/be/src/exec/partitioned-aggregation-node.cc index 6cc36be..7c75d01 100644 --- a/be/src/exec/partitioned-aggregation-node.cc +++ b/be/src/exec/partitioned-aggregation-node.cc @@ -1721,22 +1721,8 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, const vector<CodegenAnyVal>& input_vals, const CodegenAnyVal& dst, CodegenAnyVal* updated_dst_val) { DCHECK_EQ(evaluator->input_expr_ctxs().size(), input_vals.size()); - const string& symbol = - evaluator->is_merge() ? evaluator->merge_symbol() : evaluator->update_symbol(); - const ColumnType& dst_type = evaluator->intermediate_type(); - - // TODO: to support actual UDAs, not just builtin functions using the UDA interface, - // we need to load the function at this point. - Function* uda_fn = codegen->GetFunction(symbol, true); - DCHECK(uda_fn != NULL); - - vector<FunctionContext::TypeDesc> arg_types; - for (int i = 0; i < evaluator->input_expr_ctxs().size(); ++i) { - arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc( - evaluator->input_expr_ctxs()[i]->root()->type())); - } - Expr::InlineConstants( - AnyValUtil::ColumnTypeToTypeDesc(dst_type), arg_types, codegen, uda_fn); + Function* uda_fn; + RETURN_IF_ERROR(evaluator->GetUpdateOrMergeFunction(codegen, &uda_fn)); // Set up arguments for call to UDA, which are the FunctionContext*, followed by // pointers to all input values, followed by a pointer to the destination value. 
@@ -1753,6 +1739,7 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, // Create pointer to dst to pass to uda_fn. We must use the unlowered type for the // same reason as above. Value* dst_lowered_ptr = dst.GetLoweredPtr("dst_lowered_ptr"); + const ColumnType& dst_type = evaluator->intermediate_type(); Type* dst_unlowered_ptr_type = CodegenAnyVal::GetUnloweredPtrType(codegen, dst_type); Value* dst_unlowered_ptr = builder->CreateBitCast( dst_lowered_ptr, dst_unlowered_ptr_type, "dst_unlowered_ptr"); @@ -1825,13 +1812,6 @@ Status PartitionedAggregationNode::CodegenUpdateTuple( "intermediate tuple desc"); } - for (AggFnEvaluator* evaluator : aggregate_evaluators_) { - // Don't codegen things that aren't builtins (for now) - if (!evaluator->is_builtin()) { - return Status("PartitionedAggregationNode::CodegenUpdateTuple(): UDA codegen NYI"); - } - } - // Get the types to match the UpdateTuple signature Type* agg_node_type = codegen->GetType(PartitionedAggregationNode::LLVM_CLASS_NAME); Type* fn_ctx_type = codegen->GetType(FunctionContextImpl::LLVM_FUNCTIONCONTEXT_NAME); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/agg-fn-evaluator.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/agg-fn-evaluator.cc b/be/src/exprs/agg-fn-evaluator.cc index 4c6a993..b623130 100644 --- a/be/src/exprs/agg-fn-evaluator.cc +++ b/be/src/exprs/agg-fn-evaluator.cc @@ -23,8 +23,9 @@ #include "common/logging.h" #include "exec/aggregation-node.h" #include "exprs/aggregate-functions.h" -#include "exprs/expr-context.h" #include "exprs/anyval-util.h" +#include "exprs/expr-context.h" +#include "exprs/scalar-fn-call.h" #include "runtime/lib-cache.h" #include "runtime/raw-value.h" #include "runtime/runtime-state.h" @@ -94,6 +95,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn) is_analytic_fn_(is_analytic_fn), intermediate_slot_desc_(NULL), output_slot_desc_(NULL), 
+ arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs( + ColumnType::FromThrift(desc.agg_expr.arg_types))), cache_entry_(NULL), init_fn_(NULL), update_fn_(NULL), @@ -198,28 +201,15 @@ Status AggFnEvaluator::Prepare(RuntimeState* state, const RowDescriptor& desc, &cache_entry_)); } if (!fn_.aggregate_fn.remove_fn_symbol.empty()) { - RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, - &cache_entry_)); + RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, + fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, &cache_entry_)); } if (!fn_.aggregate_fn.finalize_fn_symbol.empty()) { - RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, - &cache_entry_)); - } - - vector<FunctionContext::TypeDesc> arg_types; - for (int i = 0; i < input_expr_ctxs_.size(); ++i) { - arg_types.push_back( - AnyValUtil::ColumnTypeToTypeDesc(input_expr_ctxs_[i]->root()->type())); + RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, + fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, &cache_entry_)); } - - FunctionContext::TypeDesc intermediate_type = - AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type()); - FunctionContext::TypeDesc output_type = - AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type()); - *agg_fn_ctx = FunctionContextImpl::CreateContext( - state, agg_fn_pool, intermediate_type, output_type, arg_types); + *agg_fn_ctx = FunctionContextImpl::CreateContext(state, agg_fn_pool, + GetIntermediateTypeDesc(), GetOutputTypeDesc(), arg_type_descs_); return Status::OK(); } @@ -521,6 +511,40 @@ void AggFnEvaluator::SerializeOrFinalize(FunctionContext* agg_fn_ctx, Tuple* src } } +/// Gets the update or merge function for this UDA. +Status AggFnEvaluator::GetUpdateOrMergeFunction(LlvmCodeGen* codegen, Function** uda_fn) { + const string& symbol = + is_merge_ ? 
fn_.aggregate_fn.merge_fn_symbol : fn_.aggregate_fn.update_fn_symbol; + vector<ColumnType> fn_arg_types; + for (ExprContext* input_expr_ctx : input_expr_ctxs_) { + fn_arg_types.push_back(input_expr_ctx->root()->type()); + } + // The intermediate value is passed as the last argument. + fn_arg_types.push_back(intermediate_type()); + RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, NULL, fn_arg_types, + fn_arg_types.size(), false, uda_fn, &cache_entry_)); + + // Inline constants into the function body (if there is an IR body). + if (!(*uda_fn)->isDeclaration()) { + // TODO: IMPALA-4785: we should also replace references to GetIntermediateType() + // with constants. + Expr::InlineConstants(GetOutputTypeDesc(), arg_type_descs_, codegen, *uda_fn); + *uda_fn = codegen->FinalizeFunction(*uda_fn); + if (*uda_fn == NULL) { + return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location); + } + } + return Status::OK(); +} + +FunctionContext::TypeDesc AggFnEvaluator::GetIntermediateTypeDesc() const { + return AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type()); +} + +FunctionContext::TypeDesc AggFnEvaluator::GetOutputTypeDesc() const { + return AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type()); +} + string AggFnEvaluator::DebugString(const vector<AggFnEvaluator*>& exprs) { stringstream out; out << "["; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/agg-fn-evaluator.h ---------------------------------------------------------------------- diff --git a/be/src/exprs/agg-fn-evaluator.h b/be/src/exprs/agg-fn-evaluator.h index b3ecda0..712bf40 100644 --- a/be/src/exprs/agg-fn-evaluator.h +++ b/be/src/exprs/agg-fn-evaluator.h @@ -119,8 +119,6 @@ class AggFnEvaluator { bool SupportsRemove() const { return remove_fn_ != NULL; } bool SupportsSerialize() const { return serialize_fn_ != NULL; } const std::string& fn_name() const { return fn_.name.function_name; } - const std::string& update_symbol() const { return 
fn_.aggregate_fn.update_fn_symbol; } - const std::string& merge_symbol() const { return fn_.aggregate_fn.merge_fn_symbol; } const SlotDescriptor* output_slot_desc() const { return output_slot_desc_; } static std::string DebugString(const std::vector<AggFnEvaluator*>& exprs); @@ -168,14 +166,8 @@ class AggFnEvaluator { static void Finalize(const std::vector<AggFnEvaluator*>& evaluators, const std::vector<FunctionContext*>& fn_ctxs, Tuple* src, Tuple* dst); - /// TODO: implement codegen path. These functions would return IR functions with - /// the same signature as the interpreted ones above. - /// Function* GetIrInitFn(); - /// Function* GetIrAddFn(); - /// Function* GetIrRemoveFn(); - /// Function* GetIrSerializeFn(); - /// Function* GetIrGetValueFn(); - /// Function* GetIrFinalizeFn(); + /// Gets the codegened update or merge function for this aggregate function. + Status GetUpdateOrMergeFunction(LlvmCodeGen* codegen, llvm::Function** uda_fn); private: const TFunction fn_; @@ -195,6 +187,9 @@ class AggFnEvaluator { /// expression (e.g. count(*)). std::vector<ExprContext*> input_expr_ctxs_; + /// The types of the arguments to the aggregate function. + const std::vector<FunctionContext::TypeDesc> arg_type_descs_; + /// The enum for some of the builtins that still require special cased logic. AggregationOp agg_op_; @@ -221,6 +216,12 @@ class AggFnEvaluator { /// Use Create() instead. AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn); + /// Return the intermediate type of the aggregate function. + FunctionContext::TypeDesc GetIntermediateTypeDesc() const; + + /// Return the output type of the aggregate function. + FunctionContext::TypeDesc GetOutputTypeDesc() const; + /// TODO: these functions below are not extensible and we need to use codegen to /// generate the calls into the UDA functions (like for UDFs). /// Remove these functions when this is supported. 
http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/anyval-util.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/anyval-util.cc b/be/src/exprs/anyval-util.cc index 132d6e4..c49cdb3 100644 --- a/be/src/exprs/anyval-util.cc +++ b/be/src/exprs/anyval-util.cc @@ -92,17 +92,33 @@ FunctionContext::TypeDesc AnyValUtil::ColumnTypeToTypeDesc(const ColumnType& typ return out; } +vector<FunctionContext::TypeDesc> AnyValUtil::ColumnTypesToTypeDescs( + const vector<ColumnType>& types) { + vector<FunctionContext::TypeDesc> type_descs; + for (const ColumnType& type : types) type_descs.push_back(ColumnTypeToTypeDesc(type)); + return type_descs; +} + ColumnType AnyValUtil::TypeDescToColumnType(const FunctionContext::TypeDesc& type) { switch (type.type) { - case FunctionContext::TYPE_BOOLEAN: return ColumnType(TYPE_BOOLEAN); - case FunctionContext::TYPE_TINYINT: return ColumnType(TYPE_TINYINT); - case FunctionContext::TYPE_SMALLINT: return ColumnType(TYPE_SMALLINT); - case FunctionContext::TYPE_INT: return ColumnType(TYPE_INT); - case FunctionContext::TYPE_BIGINT: return ColumnType(TYPE_BIGINT); - case FunctionContext::TYPE_FLOAT: return ColumnType(TYPE_FLOAT); - case FunctionContext::TYPE_DOUBLE: return ColumnType(TYPE_DOUBLE); - case FunctionContext::TYPE_TIMESTAMP: return ColumnType(TYPE_TIMESTAMP); - case FunctionContext::TYPE_STRING: return ColumnType(TYPE_STRING); + case FunctionContext::TYPE_BOOLEAN: + return ColumnType(TYPE_BOOLEAN); + case FunctionContext::TYPE_TINYINT: + return ColumnType(TYPE_TINYINT); + case FunctionContext::TYPE_SMALLINT: + return ColumnType(TYPE_SMALLINT); + case FunctionContext::TYPE_INT: + return ColumnType(TYPE_INT); + case FunctionContext::TYPE_BIGINT: + return ColumnType(TYPE_BIGINT); + case FunctionContext::TYPE_FLOAT: + return ColumnType(TYPE_FLOAT); + case FunctionContext::TYPE_DOUBLE: + return ColumnType(TYPE_DOUBLE); + case FunctionContext::TYPE_TIMESTAMP: + 
return ColumnType(TYPE_TIMESTAMP); + case FunctionContext::TYPE_STRING: + return ColumnType(TYPE_STRING); case FunctionContext::TYPE_DECIMAL: return ColumnType::CreateDecimalType(type.precision, type.scale); case FunctionContext::TYPE_FIXED_BUFFER: http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/anyval-util.h ---------------------------------------------------------------------- diff --git a/be/src/exprs/anyval-util.h b/be/src/exprs/anyval-util.h index e5473c7..429322a 100644 --- a/be/src/exprs/anyval-util.h +++ b/be/src/exprs/anyval-util.h @@ -227,6 +227,8 @@ class AnyValUtil { } static FunctionContext::TypeDesc ColumnTypeToTypeDesc(const ColumnType& type); + static std::vector<FunctionContext::TypeDesc> ColumnTypesToTypeDescs( + const std::vector<ColumnType>& types); // Note: constructing a ColumnType is expensive and should be avoided in query execution // paths (i.e. non-setup paths). static ColumnType TypeDescToColumnType(const FunctionContext::TypeDesc& type); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/expr.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/expr.cc b/be/src/exprs/expr.cc index 2ffddf6..f119a9d 100644 --- a/be/src/exprs/expr.cc +++ b/be/src/exprs/expr.cc @@ -637,10 +637,10 @@ int Expr::InlineConstants(LlvmCodeGen* codegen, Function* fn) { } int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type, - const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen, - Function* fn) { + const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen, + Function* fn) { int replaced = 0; - for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end; ) { + for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end;) { // Increment iter now so we don't mess it up modifying the instruction below Instruction* instr = &*(iter++); @@ -666,7 +666,7 @@ int 
Expr::InlineConstants(const FunctionContext::TypeDesc& return_type, int i_val = static_cast<int>(i_arg->getSExtValue()); // All supported constants are currently integers. call_instr->replaceAllUsesWith(ConstantInt::get(codegen->GetType(TYPE_INT), - GetConstantInt(return_type, arg_types, c_val, i_val))); + GetConstantInt(return_type, arg_types, c_val, i_val))); call_instr->eraseFromParent(); ++replaced; } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/expr.h ---------------------------------------------------------------------- diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index b77c9a2..13ba312 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -263,9 +263,11 @@ class Expr { // Any additions to this enum must be reflected in both GetConstant*() and // GetIrConstant(). enum ExprConstant { + // RETURN_TYPE_*: properties of FunctionContext::GetReturnType(). RETURN_TYPE_SIZE, // int RETURN_TYPE_PRECISION, // int RETURN_TYPE_SCALE, // int + // ARG_TYPE_* with parameter i: properties of FunctionContext::GetArgType(i). ARG_TYPE_SIZE, // int[] ARG_TYPE_PRECISION, // int[] ARG_TYPE_SCALE, // int[] @@ -289,7 +291,8 @@ class Expr { // constants to be replaced must be inlined into the function that InlineConstants() // is run on (e.g. by annotating them with IR_ALWAYS_INLINE). // - // TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in loops + // TODO: implement a loop unroller (or use LLVM's) so we can use GetConstantInt() in + // loops static int GetConstantInt(const FunctionContext& ctx, ExprConstant c, int i = -1); /// Finds all calls to Expr::GetConstantInt() in 'fn' and replaces them with the @@ -298,8 +301,8 @@ class Expr { /// 'arg_types' are the argument types of the UDF or UDAF, i.e. the values of /// FunctionContext::GetArgType(). 
static int InlineConstants(const FunctionContext::TypeDesc& return_type, - const std::vector<FunctionContext::TypeDesc>& arg_types, - LlvmCodeGen* codegen, llvm::Function* fn); + const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen, + llvm::Function* fn); static const char* LLVM_CLASS_NAME; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/scalar-fn-call.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/scalar-fn-call.cc b/be/src/exprs/scalar-fn-call.cc index c7d4e7e..06830b7 100644 --- a/be/src/exprs/scalar-fn-call.cc +++ b/be/src/exprs/scalar-fn-call.cc @@ -19,12 +19,12 @@ #include <vector> #include <gutil/strings/substitute.h> -#include <llvm/IR/Attributes.h> #include <llvm/ExecutionEngine/ExecutionEngine.h> +#include <llvm/IR/Attributes.h> #include <boost/preprocessor/punctuation/comma_if.hpp> -#include <boost/preprocessor/repetition/repeat.hpp> #include <boost/preprocessor/repetition/enum_params.hpp> +#include <boost/preprocessor/repetition/repeat.hpp> #include <boost/preprocessor/repetition/repeat_from_to.hpp> #include "codegen/codegen-anyval.h" @@ -37,8 +37,6 @@ #include "runtime/types.h" #include "udf/udf-internal.h" #include "util/debug-util.h" -#include "util/dynamic-util.h" -#include "util/symbols-util.h" #include "common/names.h" @@ -311,21 +309,28 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { } } + vector<ColumnType> arg_types; + for (const Expr* child : children_) arg_types.push_back(child->type()); + Function* udf; + RETURN_IF_ERROR(codegen->LoadFunction(fn_, fn_.scalar_fn.symbol, &type_, arg_types, + NumFixedArgs(), vararg_start_idx_ != -1, &udf, &cache_entry_)); + // Inline constants into the function if it has an IR body. 
+ if (!udf->isDeclaration()) { + InlineConstants(AnyValUtil::ColumnTypeToTypeDesc(type_), + AnyValUtil::ColumnTypesToTypeDescs(arg_types), codegen, udf); + udf = codegen->FinalizeFunction(udf); + if (udf == NULL) { + return Status( + TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location); + } + } + if (fn_.binary_type == TFunctionBinaryType::IR) { - string local_path; - RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath( - fn_.hdfs_location, LibCache::TYPE_IR, &local_path)); - // Link the UDF module into this query's main module (essentially copy the UDF - // module into the main module) so the UDF's functions are available in the main - // module. - RETURN_IF_ERROR(codegen->LinkModule(local_path)); - // Load the Prepare() and Close() functions from the LLVM module. + // LoadFunction() should have linked the IR module into 'codegen'. Now load the + // Prepare() and Close() functions from 'codegen'. RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen)); } - Function* udf; - RETURN_IF_ERROR(GetUdf(codegen, &udf)); - // Create wrapper that computes args and calls UDF stringstream fn_name; fn_name << udf->getName().str() << "Wrapper"; @@ -407,8 +412,7 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { // Add the number of varargs udf_args.push_back(codegen->GetIntConstant(TYPE_INT, NumVarArgs())); // Add all the accumulated vararg inputs as one input argument. - PointerType* vararg_type = - codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, VarArgsType())); + PointerType* vararg_type = CodegenAnyVal::GetUnloweredPtrType(codegen, VarArgsType()); udf_args.push_back(builder.CreateBitCast(varargs_buffer, vararg_type, "varargs")); } @@ -428,119 +432,11 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) { return Status::OK(); } -Status ScalarFnCall::GetUdf(LlvmCodeGen* codegen, Function** udf) { - // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd. 
- // TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd - // code. Always use the interpreted version of these functions. - // TODO: fix these built-in functions so we don't need 'broken_builtin' below. - bool broken_builtin = fn_.name.function_name == "from_utc_timestamp" || - fn_.name.function_name == "to_utc_timestamp" || - fn_.scalar_fn.symbol.find("AddSub") != string::npos; - if (fn_.binary_type == TFunctionBinaryType::NATIVE || - (fn_.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) { - // In this path, we are code that has been statically compiled to assembly. - // This can either be a UDF implemented in a .so or a builtin using the UDF - // interface with the code in impalad. - void* fn_ptr; - Status status = LibCache::instance()->GetSoFunctionPtr( - fn_.hdfs_location, fn_.scalar_fn.symbol, &fn_ptr, &cache_entry_); - if (!status.ok() && fn_.binary_type == TFunctionBinaryType::BUILTIN) { - // Builtins symbols should exist unless there is a version mismatch. - status.AddDetail(ErrorMsg(TErrorCode::MISSING_BUILTIN, - fn_.name.function_name, fn_.scalar_fn.symbol).msg()); - } - RETURN_IF_ERROR(status); - DCHECK(fn_ptr != NULL); - - // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument. - // So, the return type is void. - bool is_decimal = type().type == TYPE_DECIMAL; - Type* return_type = is_decimal ? codegen->void_type() : - CodegenAnyVal::GetLoweredType(codegen, type()); - - // Convert UDF function pointer to Function*. Start by creating a function - // prototype for it. - LlvmCodeGen::FnPrototype prototype(codegen, fn_.scalar_fn.symbol, return_type); - - if (is_decimal) { - // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument - Type* output_type = - codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, type())); - prototype.AddArgument("output", output_type); - } - - // The "FunctionContext*" argument. 
- prototype.AddArgument("ctx", - codegen->GetPtrType("class.impala_udf::FunctionContext")); - - // The "fixed" arguments for the UDF function. - for (int i = 0; i < NumFixedArgs(); ++i) { - stringstream arg_name; - arg_name << "fixed_arg_" << i; - Type* arg_type = codegen->GetPtrType( - CodegenAnyVal::GetUnloweredType(codegen, children_[i]->type())); - prototype.AddArgument(arg_name.str(), arg_type); - } - // The varargs for the UDF function if there is any. - if (NumVarArgs() > 0) { - Type* vararg_type = CodegenAnyVal::GetUnloweredPtrType( - codegen, children_[vararg_start_idx_]->type()); - prototype.AddArgument("num_var_arg", codegen->GetType(TYPE_INT)); - prototype.AddArgument("var_arg", vararg_type); - } - - // Create a Function* with the generated type. This is only a function - // declaration, not a definition, since we do not create any basic blocks or - // instructions in it. - *udf = prototype.GeneratePrototype(NULL, NULL, false); - - // Associate the dynamically loaded function pointer with the Function* we defined. - // This tells LLVM where the compiled function definition is located in memory. - codegen->execution_engine()->addGlobalMapping(*udf, fn_ptr); - } else if (fn_.binary_type == TFunctionBinaryType::BUILTIN) { - // In this path, we're running a builtin with the UDF interface. The IR is - // in the llvm module. - *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false); - if (*udf == NULL) { - // Builtins symbols should exist unless there is a version mismatch. - stringstream ss; - ss << "Builtin '" << fn_.name.function_name << "' with symbol '" - << fn_.scalar_fn.symbol << "' does not exist. " - << "Verify that all your impalads are the same version."; - return Status(ss.str()); - } - // Builtin functions may use Expr::GetConstant(). Clone the function in case we need - // to use it again, and rename it to something more manageable than the mangled name. 
- string demangled_name = SymbolsUtil::DemangleNoArgs((*udf)->getName().str()); - *udf = codegen->CloneFunction(*udf); - (*udf)->setName(demangled_name); - InlineConstants(codegen, *udf); - *udf = codegen->FinalizeFunction(*udf); - DCHECK(*udf != NULL); - } else { - // We're running an IR UDF. - DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR); - *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false); - if (*udf == NULL) { - stringstream ss; - ss << "Unable to locate function " << fn_.scalar_fn.symbol << " from LLVM module " - << fn_.hdfs_location; - return Status(ss.str()); - } - *udf = codegen->FinalizeFunction(*udf); - if (*udf == NULL) { - return Status( - TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location); - } - } - return Status::OK(); -} - Status ScalarFnCall::GetFunction(LlvmCodeGen* codegen, const string& symbol, void** fn) { - if (fn_.binary_type == TFunctionBinaryType::NATIVE || - fn_.binary_type == TFunctionBinaryType::BUILTIN) { - return LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, symbol, fn, - &cache_entry_); + if (fn_.binary_type == TFunctionBinaryType::NATIVE + || fn_.binary_type == TFunctionBinaryType::BUILTIN) { + return LibCache::instance()->GetSoFunctionPtr( + fn_.hdfs_location, symbol, fn, &cache_entry_); } else { DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR); DCHECK(codegen != NULL); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/scalar-fn-call.h ---------------------------------------------------------------------- diff --git a/be/src/exprs/scalar-fn-call.h b/be/src/exprs/scalar-fn-call.h index c8bc8c8..ffb9a2f 100644 --- a/be/src/exprs/scalar-fn-call.h +++ b/be/src/exprs/scalar-fn-call.h @@ -48,7 +48,7 @@ class TExprNode; /// - Test cancellation /// - Type descs in UDA test harness /// - Allow more functions to be NULL in UDA test harness -class ScalarFnCall: public Expr { +class ScalarFnCall : public Expr { public: virtual std::string DebugString() const; @@ 
-117,9 +117,6 @@ class ScalarFnCall: public Expr { return children_.back()->type(); } - /// Loads the native or IR function from HDFS and puts the result in *udf. - Status GetUdf(LlvmCodeGen* codegen, llvm::Function** udf); - /// Loads the native or IR function 'symbol' from HDFS and puts the result in *fn. /// If the function is loaded from an IR module, it cannot be called until the module /// has been JIT'd (i.e. after GetCodegendComputeFn() has been called). http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/timestamp-functions.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/timestamp-functions.cc b/be/src/exprs/timestamp-functions.cc index a519696..a63438c 100644 --- a/be/src/exprs/timestamp-functions.cc +++ b/be/src/exprs/timestamp-functions.cc @@ -76,7 +76,7 @@ void ThrowIfDateOutOfRange(const boost::gregorian::date& date) { // This function uses inline asm functions, which we believe to be from the boost library. // Inline asm is not currently supported by JIT, so this function should always be run in -// the interpreted mode. This is handled in ScalarFnCall::GetUdf(). +// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction(). TimestampVal TimestampFunctions::FromUtc(FunctionContext* context, const TimestampVal& ts_val, const StringVal& tz_string_val) { if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null(); @@ -114,7 +114,7 @@ TimestampVal TimestampFunctions::FromUtc(FunctionContext* context, // This function uses inline asm functions, which we believe to be from the boost library. // Inline asm is not currently supported by JIT, so this function should always be run in -// the interpreted mode. This is handled in ScalarFnCall::GetUdf(). +// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction(). 
TimestampVal TimestampFunctions::ToUtc(FunctionContext* context, const TimestampVal& ts_val, const StringVal& tz_string_val) { if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null(); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/runtime/types.cc ---------------------------------------------------------------------- diff --git a/be/src/runtime/types.cc b/be/src/runtime/types.cc index 3a04ca3..f580628 100644 --- a/be/src/runtime/types.cc +++ b/be/src/runtime/types.cc @@ -310,6 +310,12 @@ string ColumnType::DebugString() const { } } +vector<ColumnType> ColumnType::FromThrift(const vector<TColumnType>& ttypes) { + vector<ColumnType> types; + for (const TColumnType& ttype : ttypes) types.push_back(FromThrift(ttype)); + return types; +} + ostream& operator<<(ostream& os, const ColumnType& type) { os << type.DebugString(); return os; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/runtime/types.h ---------------------------------------------------------------------- diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index f265705..f2db3bd 100644 --- a/be/src/runtime/types.h +++ b/be/src/runtime/types.h @@ -147,6 +147,8 @@ struct ColumnType { return result; } + static std::vector<ColumnType> FromThrift(const std::vector<TColumnType>& ttypes); + bool operator==(const ColumnType& o) const { if (type != o.type) return false; if (children != o.children) return false; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/testutil/test-udas.cc ---------------------------------------------------------------------- diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc index 0097500..549f2f0 100644 --- a/be/src/testutil/test-udas.cc +++ b/be/src/testutil/test-udas.cc @@ -17,6 +17,9 @@ #include "testutil/test-udas.h" +#include <assert.h> + +// Don't include Impala internal headers - real UDAs won't include them. 
#include <udf/udf.h> using namespace impala_udf; @@ -48,19 +51,95 @@ void Agg(FunctionContext*, const StringVal&, const DoubleVal&, StringVal*) {} void AggInit(FunctionContext*, StringVal*){} void AggMerge(FunctionContext*, const StringVal&, StringVal*) {} StringVal AggSerialize(FunctionContext*, const StringVal& v) { return v;} -StringVal AggFinalize(FunctionContext*, const StringVal& v) { return v;} - +StringVal AggFinalize(FunctionContext*, const StringVal& v) { + return v; +} -// Defines AggIntermediate(int) returns BIGINT intermediate CHAR(10) -// TODO: StringVal should be replaced with BufferVal in Impala 2.0 -void AggIntermediate(FunctionContext*, const IntVal&, StringVal*) {} -void AggIntermediateUpdate(FunctionContext*, const IntVal&, StringVal*) {} -void AggIntermediateInit(FunctionContext*, StringVal*) {} -void AggIntermediateMerge(FunctionContext*, const StringVal&, StringVal*) {} -BigIntVal AggIntermediateFinalize(FunctionContext*, const StringVal&) { +// Defines AggIntermediate(int) returns BIGINT intermediate STRING +void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {} +void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +void AggIntermediateInit(FunctionContext* context, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + 
assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); +} +BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) { + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return BigIntVal::null(); } +// Defines AggDecimalIntermediate(DECIMAL(2,1), INT) returns DECIMAL(6,5) +// intermediate DECIMAL(4,3) +// Useful to test that type parameters are plumbed through. +void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); +
assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); +} +DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) { + assert(context->GetNumArgs() == 2); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL); + assert(context->GetArgType(0)->precision == 2); + assert(context->GetArgType(0)->scale == 1); + assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetIntermediateType().precision == 4); + assert(context->GetIntermediateType().scale == 3); + assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL); + assert(context->GetReturnType().precision == 6); + assert(context->GetReturnType().scale == 5); + return DecimalVal::null(); +} + // Defines MemTest(bigint) return bigint // "Allocates" the specified number of bytes in the update function and frees them in the // serialize function. Useful for testing mem limits. 
@@ -99,22 +178,57 @@ BigIntVal MemTestFinalize(FunctionContext* context, const BigIntVal& total) { // Defines aggregate function for testing different intermediate/output types that // computes the truncated bigint sum of many floats. void TruncSumInit(FunctionContext* context, DoubleVal* total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); *total = DoubleVal(0); } void TruncSumUpdate(FunctionContext* context, const DoubleVal& val, DoubleVal* total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); total->val += val.val; } void TruncSumMerge(FunctionContext* context, const DoubleVal& src, DoubleVal* dst) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); dst->val += src.val; } const DoubleVal TruncSumSerialize(FunctionContext* context, const DoubleVal& total) { + // Arg types should be logical input types of UDA. 
+ assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return total; } BigIntVal TruncSumFinalize(FunctionContext* context, const DoubleVal& total) { + // Arg types should be logical input types of UDA. + assert(context->GetNumArgs() == 1); + assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE); + assert(context->GetArgType(-1) == nullptr); + assert(context->GetArgType(1) == nullptr); + assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE); + assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT); return BigIntVal(static_cast<int64_t>(total.val)); } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/testutil/test-udas.h ---------------------------------------------------------------------- diff --git a/be/src/testutil/test-udas.h b/be/src/testutil/test-udas.h index 8cd38ec..57b0f06 100644 --- a/be/src/testutil/test-udas.h +++ b/be/src/testutil/test-udas.h @@ -18,6 +18,7 @@ #ifndef IMPALA_UDF_TEST_UDAS_H #define IMPALA_UDF_TEST_UDAS_H +// Don't include Impala internal headers - real UDAs won't include them. #include "udf/udf.h" using namespace impala_udf; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf-internal.h ---------------------------------------------------------------------- diff --git a/be/src/udf/udf-internal.h b/be/src/udf/udf-internal.h index bf3032e..96d0fc7 100644 --- a/be/src/udf/udf-internal.h +++ b/be/src/udf/udf-internal.h @@ -47,6 +47,12 @@ class RuntimeState; /// This class actually implements the interface of FunctionContext. This is split to /// hide the details from the external header. 
/// Note: The actual user code does not include this file. +/// +/// Exprs (e.g. UDFs and UDAs) require a FunctionContext to store state related to +/// evaluation of the expression. Each FunctionContext is associated with a backend Expr +/// or AggFnEvaluator, which is derived from a TExprNode generated by the Impala frontend. +/// FunctionContexts are allocated and managed by ExprContext. Exprs shouldn't try to +/// create FunctionContext themselves. class FunctionContextImpl { public: /// Create a FunctionContext for a UDF. Caller is responsible for deleting it. http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf-ir.cc ---------------------------------------------------------------------- diff --git a/be/src/udf/udf-ir.cc b/be/src/udf/udf-ir.cc index 24773f0..c12133c 100644 --- a/be/src/udf/udf-ir.cc +++ b/be/src/udf/udf-ir.cc @@ -34,6 +34,10 @@ int FunctionContext::GetNumArgs() const { return impl_->arg_types_.size(); } +const FunctionContext::TypeDesc& FunctionContext::GetIntermediateType() const { + return impl_->intermediate_type_; +} + const FunctionContext::TypeDesc& FunctionContext::GetReturnType() const { return impl_->return_type_; } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf.h ---------------------------------------------------------------------- diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index 461482c..4fdca2d 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -189,21 +189,25 @@ class FunctionContext { const TypeDesc& GetIntermediateType() const; /// Returns the number of arguments to this function (not including the FunctionContext* - /// argument). + /// argument or the output of a UDA). + /// For UDAs, returns the number of logical arguments of the aggregate function, not + /// the number of arguments of the C++ function being executed. 
int GetNumArgs() const; /// Returns the type information for the arg_idx-th argument (0-indexed, not including /// the FunctionContext* argument). Returns NULL if arg_idx is invalid. + /// For UDAs, returns the logical argument types of the aggregate function, not the + /// argument types of the C++ function being executed. const TypeDesc* GetArgType(int arg_idx) const; - /// Returns true if the arg_idx-th input argument (0 indexed, not including the - /// FunctionContext* argument) is a constant (e.g. 5, "string", 1 + 1). + /// Returns true if the arg_idx-th input argument (indexed in the same way as + /// GetArgType()) is a constant (e.g. 5, "string", 1 + 1). bool IsArgConstant(int arg_idx) const; - /// Returns a pointer to the value of the arg_idx-th input argument (0 indexed, not - /// including the FunctionContext* argument). Returns NULL if the argument is not - /// constant. This function can be used to obtain user-specified constants in a UDF's - /// Init() or Close() functions. + /// Returns a pointer to the value of the arg_idx-th input argument (indexed in the + /// same way as GetArgType()). Returns NULL if the argument is not constant. This + /// function can be used to obtain user-specified constants in a UDF's Init() or + /// Close() functions. AnyVal* GetConstantArg(int arg_idx) const; /// TODO: Do we need to add arbitrary key/value metadata. This would be plumbed http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/common/thrift/Exprs.thrift ---------------------------------------------------------------------- diff --git a/common/thrift/Exprs.thrift b/common/thrift/Exprs.thrift index 33b859a..fc0f4ee 100644 --- a/common/thrift/Exprs.thrift +++ b/common/thrift/Exprs.thrift @@ -112,9 +112,14 @@ struct TStringLiteral { 1: required string value; } +// Additional information for aggregate functions. struct TAggregateExpr { // Indicates whether this expr is the merge() of an aggregation. 
1: required bool is_merge_agg + + // The types of the input arguments to the aggregate function. May differ from the + // input expr types if this is the merge() of an aggregation. + 2: required list<Types.TColumnType> arg_types; } // This is essentially a union over the subclasses of Expr. http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java ---------------------------------------------------------------------- diff --git a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java index ec88aae..5dbde32 100644 --- a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java +++ b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java @@ -673,7 +673,8 @@ public class AggregateInfo extends AggregateInfoBase { * materialized slots of the output tuple corresponds to the number of materialized * aggregate functions plus the number of grouping exprs. Also checks that the return * types of the aggregate and grouping exprs correspond to the slots in the output - * tuple. + * tuple and that the input types stored in the merge aggregation are consistent + * with the input exprs. */ public void checkConsistency() { ArrayList<SlotDescriptor> slots = outputTupleDesc_.getSlots(); @@ -707,6 +708,13 @@ public class AggregateInfo extends AggregateInfoBase { slotType.toString())); ++slotIdx; } + if (mergeAggInfo_ != null) { + // Check that the argument types in mergeAggInfo_ are consistent with input exprs. 
+ for (int i = 0; i < aggregateExprs_.size(); ++i) { + FunctionCallExpr mergeAggExpr = mergeAggInfo_.aggregateExprs_.get(i); + mergeAggExpr.validateMergeAggFn(aggregateExprs_.get(i)); + } + } } /** http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java ---------------------------------------------------------------------- diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java index 15af25f..4d7dca8 100644 --- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java +++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java @@ -30,6 +30,7 @@ import org.apache.impala.catalog.Type; import org.apache.impala.common.AnalysisException; import org.apache.impala.common.TreeNode; import org.apache.impala.thrift.TAggregateExpr; +import org.apache.impala.thrift.TColumnType; import org.apache.impala.thrift.TExprNode; import org.apache.impala.thrift.TExprNodeType; import org.apache.impala.thrift.TFunctionBinaryType; @@ -45,10 +46,12 @@ public class FunctionCallExpr extends Expr { private boolean isAnalyticFnCall_ = false; private boolean isInternalFnCall_ = false; - // Indicates whether this is a merge aggregation function that should use the merge - // instead of the update symbol. This flag also affects the behavior of - // resetAnalysisState() which is used during expr substitution. - private final boolean isMergeAggFn_; + // Non-null iff this is an aggregation function that executes the Merge() step. This + // is an analyzed clone of the FunctionCallExpr that executes the Update() function + // feeding into this Merge(). This is stored so that we can access the types of the + // original input argument exprs. Note that the nullness affects the behaviour of + // resetAnalysisState(), which is used during expr substitution. 
+ private final FunctionCallExpr mergeAggInputFn_; // Printed in toSqlImpl(), if set. Used for merge agg fns. private String label_; @@ -62,15 +65,16 @@ public class FunctionCallExpr extends Expr { } public FunctionCallExpr(FunctionName fnName, FunctionParams params) { - this(fnName, params, false); + this(fnName, params, null); } - private FunctionCallExpr( - FunctionName fnName, FunctionParams params, boolean isMergeAggFn) { + private FunctionCallExpr(FunctionName fnName, FunctionParams params, + FunctionCallExpr mergeAggInputFn) { super(); fnName_ = fnName; params_ = params; - isMergeAggFn_ = isMergeAggFn; + mergeAggInputFn_ = + mergeAggInputFn == null ? null : (FunctionCallExpr)mergeAggInputFn.clone(); if (params.exprs() != null) children_ = Lists.newArrayList(params_.exprs()); } @@ -99,12 +103,12 @@ public class FunctionCallExpr extends Expr { Preconditions.checkState(agg.isAnalyzed()); Preconditions.checkState(agg.isAggregateFunction()); FunctionCallExpr result = new FunctionCallExpr( - agg.fnName_, new FunctionParams(false, params), true); + agg.fnName_, new FunctionParams(false, params), agg); // Inherit the function object from 'agg'. result.fn_ = agg.fn_; result.type_ = agg.type_; // Set an explicit label based on the input agg. - if (agg.isMergeAggFn_) { + if (agg.isMergeAggFn()) { result.label_ = agg.label_; } else { // fn(input) becomes fn:merge(input). @@ -123,7 +127,8 @@ public class FunctionCallExpr extends Expr { fnName_ = other.fnName_; isAnalyticFnCall_ = other.isAnalyticFnCall_; isInternalFnCall_ = other.isInternalFnCall_; - isMergeAggFn_ = other.isMergeAggFn_; + mergeAggInputFn_ = + other.mergeAggInputFn_ == null ? null : (FunctionCallExpr)other.mergeAggInputFn_.clone(); // Clone the params in a way that keeps the children_ and the params.exprs() // in sync. The children have already been cloned in the super c'tor. 
if (other.params_.isStar()) { @@ -135,7 +140,7 @@ public class FunctionCallExpr extends Expr { label_ = other.label_; } - public boolean isMergeAggFn() { return isMergeAggFn_; } + public boolean isMergeAggFn() { return mergeAggInputFn_ != null; } @Override public void resetAnalysisState() { @@ -144,7 +149,7 @@ public class FunctionCallExpr extends Expr { // intermediate agg type is not the same as the output type. Preserve the original // fn_ such that analyze() hits the special-case code for merge agg fns that // handles this case. - if (!isMergeAggFn_) fn_ = null; + if (!isMergeAggFn()) fn_ = null; } @Override @@ -160,7 +165,7 @@ public class FunctionCallExpr extends Expr { public String toSqlImpl() { if (label_ != null) return label_; // Merge agg fns should have an explicit label. - Preconditions.checkState(!isMergeAggFn_); + Preconditions.checkState(!isMergeAggFn()); StringBuilder sb = new StringBuilder(); sb.append(fnName_).append("("); if (params_.isStar()) sb.append("*"); @@ -226,7 +231,12 @@ public class FunctionCallExpr extends Expr { protected void toThrift(TExprNode msg) { if (isAggregateFunction() || isAnalyticFnCall_) { msg.node_type = TExprNodeType.AGGREGATE_EXPR; - if (!isAnalyticFnCall_) msg.setAgg_expr(new TAggregateExpr(isMergeAggFn_)); + List<TColumnType> aggFnArgTypes = Lists.newArrayList(); + FunctionCallExpr inputAggFn = isMergeAggFn() ? mergeAggInputFn_ : this; + for (Expr child: inputAggFn.children_) { + aggFnArgTypes.add(child.getType().toThrift()); + } + msg.setAgg_expr(new TAggregateExpr(isMergeAggFn(), aggFnArgTypes)); } else { msg.node_type = TExprNodeType.FUNCTION_CALL; } @@ -383,7 +393,7 @@ public class FunctionCallExpr extends Expr { protected void analyzeImpl(Analyzer analyzer) throws AnalysisException { fnName_.analyze(analyzer); - if (isMergeAggFn_) { + if (isMergeAggFn()) { // This is the function call expr after splitting up to a merge aggregation. 
// The function has already been analyzed so just do the minimal sanity // check here. @@ -524,6 +534,25 @@ public class FunctionCallExpr extends Expr { } } + /** + * Validate that the internal state, specifically types, is consistent between the + * Update() and Merge() aggregate functions. + */ + void validateMergeAggFn(FunctionCallExpr inputAggFn) { + Preconditions.checkState(isMergeAggFn()); + List<Expr> copiedInputExprs = mergeAggInputFn_.getChildren(); + List<Expr> inputExprs = inputAggFn.getChildren(); + Preconditions.checkState(copiedInputExprs.size() == inputExprs.size()); + for (int i = 0; i < inputExprs.size(); ++i) { + Type copiedInputType = copiedInputExprs.get(i).getType(); + Type inputType = inputExprs.get(i).getType(); + Preconditions.checkState(copiedInputType.equals(inputType), + String.format("Copied expr %s arg type %s differs from input expr type %s " + + "in original expr %s", toSql(), copiedInputType.toSql(), + inputType.toSql(), inputAggFn.toSql())); + } + } + @Override public Expr clone() { return new FunctionCallExpr(this); } } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/testdata/workloads/functional-query/queries/QueryTest/uda.test ---------------------------------------------------------------------- diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test index 05e24e7..3a9bbbe 100644 --- a/testdata/workloads/functional-query/queries/QueryTest/uda.test +++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test @@ -69,3 +69,22 @@ from functional.alltypesagg ---- TYPES bigint,bigint ==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. 
+# This relies on asserts in the UDA function +select agg_intermediate(int_col), count(*) +from functional.alltypesagg +---- RESULTS +NULL,11000 +---- TYPES +bigint,bigint +==== +---- QUERY +# Test that all types are exposed via the FunctionContext correctly. +# This relies on asserts in the UDA function +select agg_decimal_intermediate(cast(d1 as decimal(2,1)), 2), count(*) +from functional.decimal_tbl +---- RESULTS +NULL,5 +---- TYPES +decimal,bigint http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/tests/query_test/test_udfs.py ---------------------------------------------------------------------- diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py index 746cf88..56ce233 100644 --- a/tests/query_test/test_udfs.py +++ b/tests/query_test/test_udfs.py @@ -93,6 +93,16 @@ update_fn='ToggleNullUpdate' merge_fn='ToggleNullMerge'; create aggregate function {database}.count_nulls(bigint) returns bigint location '{location}' update_fn='CountNullsUpdate' merge_fn='CountNullsMerge'; + +create aggregate function {database}.agg_intermediate(int) +returns bigint intermediate string location '{location}' +init_fn='AggIntermediateInit' update_fn='AggIntermediateUpdate' +merge_fn='AggIntermediateMerge' finalize_fn='AggIntermediateFinalize'; + +create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), int) +returns decimal(6,5) intermediate decimal(4,3) location '{location}' +init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate' +merge_fn='AggDecimalIntermediateMerge' finalize_fn='AggDecimalIntermediateFinalize'; """ # Create test UDF functions in {database} from library {location}
