IMPALA-1430,IMPALA-4878,IMPALA-4879: codegen native UDAs

This reuses the existing infrastructure for codegening builtin UDAs and
calls to UDFs. GetUdf() is refactored to support both UDFs and UDAs:
the shared function-loading logic now lives in LlvmCodeGen::LoadFunction().

IR UDAs are still not allowed by the frontend. It's unclear whether we
want to enable them going forward, given the difficulty of testing and
supporting IR UDFs/UDAs.

This also fixes some bugs in the Get*Type() methods of
FunctionContext. GetArgType() and related methods now always return the
logical input types of the aggregate function. Getting the tests to pass
also required fixing IMPALA-4878, because the test UDAs call
GetIntermediateType().

Testing:
test_udfs.py tests UDAs with codegen enabled and disabled.

Added asserts to the test UDAs to check that the correct types are
passed in via the FunctionContext.

Change-Id: Id1708eaa96eb76fb9bec5eeabf209f81c88eec2f
Reviewed-on: http://gerrit.cloudera.org:8080/5161
Reviewed-by: Dan Hecht <[email protected]>
Tested-by: Impala Public Jenkins


Project: http://git-wip-us.apache.org/repos/asf/incubator-impala/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-impala/commit/d2d3f4c1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-impala/tree/d2d3f4c1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-impala/diff/d2d3f4c1

Branch: refs/heads/master
Commit: d2d3f4c1a6eefb3f1335da8bd6791fbebd63d98b
Parents: 1335af3
Author: Tim Armstrong <[email protected]>
Authored: Fri Nov 18 13:34:15 2016 -0800
Committer: Impala Public Jenkins <[email protected]>
Committed: Fri Feb 10 02:18:32 2017 +0000

----------------------------------------------------------------------
 be/src/codegen/codegen-anyval.cc                |   4 +-
 be/src/codegen/codegen-anyval.h                 |   4 +-
 be/src/codegen/llvm-codegen.cc                  | 113 +++++++++++++-
 be/src/codegen/llvm-codegen.h                   |  25 ++-
 be/src/exec/partitioned-aggregation-node.cc     |  26 +---
 be/src/exprs/agg-fn-evaluator.cc                |  64 +++++---
 be/src/exprs/agg-fn-evaluator.h                 |  21 +--
 be/src/exprs/anyval-util.cc                     |  34 ++--
 be/src/exprs/anyval-util.h                      |   2 +
 be/src/exprs/expr.cc                            |   8 +-
 be/src/exprs/expr.h                             |   9 +-
 be/src/exprs/scalar-fn-call.cc                  | 154 +++----------------
 be/src/exprs/scalar-fn-call.h                   |   5 +-
 be/src/exprs/timestamp-functions.cc             |   4 +-
 be/src/runtime/types.cc                         |   6 +
 be/src/runtime/types.h                          |   2 +
 be/src/testutil/test-udas.cc                    | 132 ++++++++++++++--
 be/src/testutil/test-udas.h                     |   1 +
 be/src/udf/udf-internal.h                       |   6 +
 be/src/udf/udf-ir.cc                            |   4 +
 be/src/udf/udf.h                                |  18 ++-
 common/thrift/Exprs.thrift                      |   5 +
 .../apache/impala/analysis/AggregateInfo.java   |  10 +-
 .../impala/analysis/FunctionCallExpr.java       |  61 ++++++--
 .../functional-query/queries/QueryTest/uda.test |  19 +++
 tests/query_test/test_udfs.py                   |  10 ++
 26 files changed, 500 insertions(+), 247 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/codegen-anyval.cc
----------------------------------------------------------------------
diff --git a/be/src/codegen/codegen-anyval.cc b/be/src/codegen/codegen-anyval.cc
index bd27409..a778812 100644
--- a/be/src/codegen/codegen-anyval.cc
+++ b/be/src/codegen/codegen-anyval.cc
@@ -67,7 +67,7 @@ Type* CodegenAnyVal::GetLoweredType(LlvmCodeGen* cg, const 
ColumnType& type) {
   }
 }
 
-Type* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& 
type) {
+PointerType* CodegenAnyVal::GetLoweredPtrType(LlvmCodeGen* cg, const 
ColumnType& type) {
   return GetLoweredType(cg, type)->getPointerTo();
 }
 
@@ -116,7 +116,7 @@ Type* CodegenAnyVal::GetUnloweredType(LlvmCodeGen* cg, 
const ColumnType& type) {
   return result;
 }
 
-Type* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& 
type) {
+PointerType* CodegenAnyVal::GetUnloweredPtrType(LlvmCodeGen* cg, const 
ColumnType& type) {
   return GetUnloweredType(cg, type)->getPointerTo();
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/codegen-anyval.h
----------------------------------------------------------------------
diff --git a/be/src/codegen/codegen-anyval.h b/be/src/codegen/codegen-anyval.h
index c07f3eb..13494ac 100644
--- a/be/src/codegen/codegen-anyval.h
+++ b/be/src/codegen/codegen-anyval.h
@@ -95,7 +95,7 @@ class CodegenAnyVal {
 
   /// Returns the lowered AnyVal pointer type associated with 'type'.
   /// E.g.: TYPE_BOOLEAN => i16*
-  static llvm::Type* GetLoweredPtrType(LlvmCodeGen* cg, const ColumnType& 
type);
+  static llvm::PointerType* GetLoweredPtrType(LlvmCodeGen* cg, const 
ColumnType& type);
 
   /// Returns the unlowered AnyVal type associated with 'type'.
   /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"
@@ -103,7 +103,7 @@ class CodegenAnyVal {
 
   /// Returns the unlowered AnyVal pointer type associated with 'type'.
   /// E.g.: TYPE_BOOLEAN => %"struct.impala_udf::BooleanVal"*
-  static llvm::Type* GetUnloweredPtrType(LlvmCodeGen* cg, const ColumnType& 
type);
+  static llvm::PointerType* GetUnloweredPtrType(LlvmCodeGen* cg, const 
ColumnType& type);
 
   /// Return the constant type-lowered value corresponding to a null *Val.
   /// E.g.: passing TYPE_DOUBLE (corresponding to the lowered DoubleVal { i8, 
double })

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/llvm-codegen.cc
----------------------------------------------------------------------
diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc
index 3d730d5..6cf43f6 100644
--- a/be/src/codegen/llvm-codegen.cc
+++ b/be/src/codegen/llvm-codegen.cc
@@ -70,6 +70,7 @@
 #include "util/hdfs-util.h"
 #include "util/path-builder.h"
 #include "util/runtime-profile-counters.h"
+#include "util/symbols-util.h"
 #include "util/test-info.h"
 
 #include "common/names.h"
@@ -766,8 +767,116 @@ Function* LlvmCodeGen::FnPrototype::GeneratePrototype(
   return fn;
 }
 
-int LlvmCodeGen::ReplaceCallSites(Function* caller, Function* new_fn,
-    const string& target_name) {
+Status LlvmCodeGen::LoadFunction(const TFunction& fn, const std::string& 
symbol,
+    const ColumnType* return_type, const std::vector<ColumnType>& arg_types,
+    int num_fixed_args, bool has_varargs, Function** llvm_fn,
+    LibCacheEntry** cache_entry) {
+  DCHECK_GE(arg_types.size(), num_fixed_args);
+  DCHECK(has_varargs || arg_types.size() == num_fixed_args);
+  DCHECK(!has_varargs || arg_types.size() > num_fixed_args);
+  // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot 
be JIT'd.
+  // TimestampFunctions::AddSub() contains a try/catch which doesn't work in 
JIT'd
+  // code. Always use the interpreted version of these functions.
+  // TODO: fix these built-in functions so we don't need 'broken_builtin' 
below.
+  bool broken_builtin = fn.name.function_name == "from_utc_timestamp"
+      || fn.name.function_name == "to_utc_timestamp"
+      || symbol.find("AddSub") != string::npos;
+  if (fn.binary_type == TFunctionBinaryType::NATIVE
+      || (fn.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) {
+    // In this path, we are calling a precompiled native function, either a UDF
+    // in a .so or a builtin using the UDF interface.
+    void* fn_ptr;
+    Status status = LibCache::instance()->GetSoFunctionPtr(
+        fn.hdfs_location, symbol, &fn_ptr, cache_entry);
+    if (!status.ok() && fn.binary_type == TFunctionBinaryType::BUILTIN) {
+      // Builtin symbols should exist unless there is a version mismatch.
+      status.AddDetail(
+          ErrorMsg(TErrorCode::MISSING_BUILTIN, fn.name.function_name, 
symbol).msg());
+    }
+    RETURN_IF_ERROR(status);
+    DCHECK(fn_ptr != NULL);
+
+    // Per the x64 ABI, DecimalVals are returned via a DecimalVal* output 
argument.
+    // So, the return type is void.
+    bool is_decimal = return_type != NULL && return_type->type == TYPE_DECIMAL;
+    Type* llvm_return_type = return_type == NULL || is_decimal ?
+        void_type() :
+        CodegenAnyVal::GetLoweredType(this, *return_type);
+
+    // Convert UDF function pointer to Function*. Start by creating a function
+    // prototype for it.
+    FnPrototype prototype(this, symbol, llvm_return_type);
+
+    if (is_decimal) {
+      // Per the x64 ABI, DecimalVals are returned via a DecimalVal* output 
argument
+      Type* output_type = CodegenAnyVal::GetUnloweredPtrType(this, 
*return_type);
+      prototype.AddArgument("output", output_type);
+    }
+
+    // The "FunctionContext*" argument.
+    prototype.AddArgument("ctx", 
GetPtrType("class.impala_udf::FunctionContext"));
+
+    // The "fixed" arguments for the UDF function, followed by the variable 
arguments,
+    // if any.
+    for (int i = 0; i < num_fixed_args; ++i) {
+      Type* arg_type = CodegenAnyVal::GetUnloweredPtrType(this, arg_types[i]);
+      prototype.AddArgument(Substitute("fixed_arg_$0", i), arg_type);
+    }
+
+    if (has_varargs) {
+      prototype.AddArgument("num_var_arg", GetType(TYPE_INT));
+      // Get the vararg type from the first vararg.
+      prototype.AddArgument(
+          "var_arg", CodegenAnyVal::GetUnloweredPtrType(this, 
arg_types[num_fixed_args]));
+    }
+
+    // Create a Function* with the generated type. This is only a function
+    // declaration, not a definition, since we do not create any basic blocks 
or
+    // instructions in it.
+    *llvm_fn = prototype.GeneratePrototype(NULL, NULL, false);
+
+    // Associate the dynamically loaded function pointer with the Function* we 
defined.
+    // This tells LLVM where the compiled function definition is located in 
memory.
+    execution_engine_->addGlobalMapping(*llvm_fn, fn_ptr);
+  } else if (fn.binary_type == TFunctionBinaryType::BUILTIN) {
+    // In this path, we're running a builtin with the UDF interface. The IR is
+    // in the llvm module. Builtin functions may use Expr::GetConstant(). 
Clone the
+    // function so that we can replace constants in the copied function.
+    *llvm_fn = GetFunction(symbol, true);
+    if (*llvm_fn == NULL) {
+      // Builtin symbols should exist unless there is a version mismatch.
+      return Status(Substitute("Builtin '$0' with symbol '$1' does not exist. 
Verify "
+                               "that all your impalads are the same version.",
+          fn.name.function_name, symbol));
+    }
+    // Rename the function to something more readable than the mangled name.
+    string demangled_name = 
SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str());
+    (*llvm_fn)->setName(demangled_name);
+  } else {
+    // We're running an IR UDF.
+    DCHECK_EQ(fn.binary_type, TFunctionBinaryType::IR);
+
+    string local_path;
+    RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath(
+        fn.hdfs_location, LibCache::TYPE_IR, &local_path));
+    // Link the UDF module into this query's main module so the UDF's 
functions are
+    // available in the main module.
+    RETURN_IF_ERROR(LinkModule(local_path));
+
+    *llvm_fn = GetFunction(symbol, true);
+    if (*llvm_fn == NULL) {
+      return Status(Substitute("Unable to load function '$0' from LLVM module 
'$1'",
+          symbol, fn.hdfs_location));
+    }
+    // Rename the function to something more readable than the mangled name.
+    string demangled_name = 
SymbolsUtil::DemangleNoArgs((*llvm_fn)->getName().str());
+    (*llvm_fn)->setName(demangled_name);
+  }
+  return Status::OK();
+}
+
+int LlvmCodeGen::ReplaceCallSites(
+    Function* caller, Function* new_fn, const string& target_name) {
   DCHECK(!is_compiled_);
   DCHECK(caller->getParent() == module_);
   DCHECK(caller != NULL);

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/codegen/llvm-codegen.h
----------------------------------------------------------------------
diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h
index 9615664..d51faab 100644
--- a/be/src/codegen/llvm-codegen.h
+++ b/be/src/codegen/llvm-codegen.h
@@ -294,6 +294,23 @@ class LlvmCodeGen {
   /// functions.
   Status FinalizeModule();
 
+  /// Loads a native or IR function 'fn' with symbol 'symbol' from the 
builtins or
+  /// an external library and puts the result in *llvm_fn. *llvm_fn can be 
safely
+  /// modified in place, because it is either newly generated or cloned. The 
caller must
+  /// call FinalizeFunction() on 'llvm_fn' once it is done modifying it. The 
function has
+  /// return type 'return_type' (void if 'return_type' is NULL) and input 
argument types
+  /// 'arg_types'. The first 'num_fixed_args' arguments are fixed arguments, 
and the
+  /// remaining arguments are varargs. 'has_varargs' indicates whether the 
function
+  /// accepts varargs. If 'has_varargs' is true, there must be at least one 
vararg. If
+  /// the function is loaded from a library, 'cache_entry' is updated to point 
to the
+  /// library containing the function. If 'cache_entry' is set to a non-NULL 
value by
+  /// this function, the caller must call LibCache::DecrementUseCount() on it 
when done
+  /// using the function.
+  Status LoadFunction(const TFunction& fn, const std::string& symbol,
+      const ColumnType* return_type, const std::vector<ColumnType>& arg_types,
+      int num_fixed_args, bool has_varargs, llvm::Function** llvm_fn,
+      LibCacheEntry** cache_entry);
+
   /// Replaces all instructions in 'caller' that call 'target_name' with a call
   /// instruction to 'new_fn'. Returns the number of call sites updated.
   ///
@@ -485,10 +502,6 @@ class LlvmCodeGen {
   llvm::Value* CodegenArrayAt(
       LlvmBuilder*, llvm::Value* array, int idx, const char* name = "");
 
-  /// Loads a module at 'file' and links it to the module associated with
-  /// this LlvmCodeGen object. The module must be on the local filesystem.
-  Status LinkModule(const std::string& file);
-
   /// If there are more than this number of expr trees (or functions that 
evaluate
   /// expressions), avoid inlining avoid inlining for the exprs exceeding this 
threshold.
   static const int CODEGEN_INLINE_EXPRS_THRESHOLD = 100;
@@ -538,6 +551,10 @@ class LlvmCodeGen {
   Status LoadModuleFromMemory(std::unique_ptr<llvm::MemoryBuffer> 
module_ir_buf,
       std::string module_name, std::unique_ptr<llvm::Module>* module);
 
+  /// Loads a module at 'file' and links it to the module associated with
+  /// this LlvmCodeGen object. The module must be on the local filesystem.
+  Status LinkModule(const std::string& file);
+
   /// Strip global constructors and destructors from an LLVM module. We never 
run them
   /// anyway (they must be explicitly invoked) so it is dead code.
   static void StripGlobalCtorsDtors(llvm::Module* module);

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exec/partitioned-aggregation-node.cc
----------------------------------------------------------------------
diff --git a/be/src/exec/partitioned-aggregation-node.cc 
b/be/src/exec/partitioned-aggregation-node.cc
index 6cc36be..7c75d01 100644
--- a/be/src/exec/partitioned-aggregation-node.cc
+++ b/be/src/exec/partitioned-aggregation-node.cc
@@ -1721,22 +1721,8 @@ Status 
PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen,
     const vector<CodegenAnyVal>& input_vals, const CodegenAnyVal& dst,
     CodegenAnyVal* updated_dst_val) {
   DCHECK_EQ(evaluator->input_expr_ctxs().size(), input_vals.size());
-  const string& symbol =
-      evaluator->is_merge() ? evaluator->merge_symbol() : 
evaluator->update_symbol();
-  const ColumnType& dst_type = evaluator->intermediate_type();
-
-  // TODO: to support actual UDAs, not just builtin functions using the UDA 
interface,
-  // we need to load the function at this point.
-  Function* uda_fn = codegen->GetFunction(symbol, true);
-  DCHECK(uda_fn != NULL);
-
-  vector<FunctionContext::TypeDesc> arg_types;
-  for (int i = 0; i < evaluator->input_expr_ctxs().size(); ++i) {
-    arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(
-        evaluator->input_expr_ctxs()[i]->root()->type()));
-  }
-  Expr::InlineConstants(
-      AnyValUtil::ColumnTypeToTypeDesc(dst_type), arg_types, codegen, uda_fn);
+  Function* uda_fn;
+  RETURN_IF_ERROR(evaluator->GetUpdateOrMergeFunction(codegen, &uda_fn));
 
   // Set up arguments for call to UDA, which are the FunctionContext*, 
followed by
   // pointers to all input values, followed by a pointer to the destination 
value.
@@ -1753,6 +1739,7 @@ Status 
PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen,
   // Create pointer to dst to pass to uda_fn. We must use the unlowered type 
for the
   // same reason as above.
   Value* dst_lowered_ptr = dst.GetLoweredPtr("dst_lowered_ptr");
+  const ColumnType& dst_type = evaluator->intermediate_type();
   Type* dst_unlowered_ptr_type = CodegenAnyVal::GetUnloweredPtrType(codegen, 
dst_type);
   Value* dst_unlowered_ptr = builder->CreateBitCast(
       dst_lowered_ptr, dst_unlowered_ptr_type, "dst_unlowered_ptr");
@@ -1825,13 +1812,6 @@ Status PartitionedAggregationNode::CodegenUpdateTuple(
                   "intermediate tuple desc");
   }
 
-  for (AggFnEvaluator* evaluator : aggregate_evaluators_) {
-    // Don't codegen things that aren't builtins (for now)
-    if (!evaluator->is_builtin()) {
-      return Status("PartitionedAggregationNode::CodegenUpdateTuple(): UDA 
codegen NYI");
-    }
-  }
-
   // Get the types to match the UpdateTuple signature
   Type* agg_node_type = 
codegen->GetType(PartitionedAggregationNode::LLVM_CLASS_NAME);
   Type* fn_ctx_type = 
codegen->GetType(FunctionContextImpl::LLVM_FUNCTIONCONTEXT_NAME);

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/agg-fn-evaluator.cc
----------------------------------------------------------------------
diff --git a/be/src/exprs/agg-fn-evaluator.cc b/be/src/exprs/agg-fn-evaluator.cc
index 4c6a993..b623130 100644
--- a/be/src/exprs/agg-fn-evaluator.cc
+++ b/be/src/exprs/agg-fn-evaluator.cc
@@ -23,8 +23,9 @@
 #include "common/logging.h"
 #include "exec/aggregation-node.h"
 #include "exprs/aggregate-functions.h"
-#include "exprs/expr-context.h"
 #include "exprs/anyval-util.h"
+#include "exprs/expr-context.h"
+#include "exprs/scalar-fn-call.h"
 #include "runtime/lib-cache.h"
 #include "runtime/raw-value.h"
 #include "runtime/runtime-state.h"
@@ -94,6 +95,8 @@ AggFnEvaluator::AggFnEvaluator(const TExprNode& desc, bool 
is_analytic_fn)
     is_analytic_fn_(is_analytic_fn),
     intermediate_slot_desc_(NULL),
     output_slot_desc_(NULL),
+    arg_type_descs_(AnyValUtil::ColumnTypesToTypeDescs(
+        ColumnType::FromThrift(desc.agg_expr.arg_types))),
     cache_entry_(NULL),
     init_fn_(NULL),
     update_fn_(NULL),
@@ -198,28 +201,15 @@ Status AggFnEvaluator::Prepare(RuntimeState* state, const 
RowDescriptor& desc,
         &cache_entry_));
   }
   if (!fn_.aggregate_fn.remove_fn_symbol.empty()) {
-    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
-        fn_.hdfs_location, fn_.aggregate_fn.remove_fn_symbol, &remove_fn_,
-        &cache_entry_));
+    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
+        fn_.aggregate_fn.remove_fn_symbol, &remove_fn_, &cache_entry_));
   }
   if (!fn_.aggregate_fn.finalize_fn_symbol.empty()) {
-    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(
-        fn_.hdfs_location, fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_,
-        &cache_entry_));
-  }
-
-  vector<FunctionContext::TypeDesc> arg_types;
-  for (int i = 0; i < input_expr_ctxs_.size(); ++i) {
-    arg_types.push_back(
-        AnyValUtil::ColumnTypeToTypeDesc(input_expr_ctxs_[i]->root()->type()));
+    RETURN_IF_ERROR(LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location,
+        fn_.aggregate_fn.finalize_fn_symbol, &finalize_fn_, &cache_entry_));
   }
-
-  FunctionContext::TypeDesc intermediate_type =
-      AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type());
-  FunctionContext::TypeDesc output_type =
-       AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type());
-  *agg_fn_ctx = FunctionContextImpl::CreateContext(
-      state, agg_fn_pool, intermediate_type, output_type, arg_types);
+  *agg_fn_ctx = FunctionContextImpl::CreateContext(state, agg_fn_pool,
+      GetIntermediateTypeDesc(), GetOutputTypeDesc(), arg_type_descs_);
   return Status::OK();
 }
 
@@ -521,6 +511,40 @@ void AggFnEvaluator::SerializeOrFinalize(FunctionContext* 
agg_fn_ctx, Tuple* src
   }
 }
 
+/// Gets the update or merge function for this UDA.
+Status AggFnEvaluator::GetUpdateOrMergeFunction(LlvmCodeGen* codegen, 
Function** uda_fn) {
+  const string& symbol =
+      is_merge_ ? fn_.aggregate_fn.merge_fn_symbol : 
fn_.aggregate_fn.update_fn_symbol;
+  vector<ColumnType> fn_arg_types;
+  for (ExprContext* input_expr_ctx : input_expr_ctxs_) {
+    fn_arg_types.push_back(input_expr_ctx->root()->type());
+  }
+  // The intermediate value is passed as the last argument.
+  fn_arg_types.push_back(intermediate_type());
+  RETURN_IF_ERROR(codegen->LoadFunction(fn_, symbol, NULL, fn_arg_types,
+      fn_arg_types.size(), false, uda_fn, &cache_entry_));
+
+  // Inline constants into the function body (if there is an IR body).
+  if (!(*uda_fn)->isDeclaration()) {
+    // TODO: IMPALA-4785: we should also replace references to 
GetIntermediateType()
+    // with constants.
+    Expr::InlineConstants(GetOutputTypeDesc(), arg_type_descs_, codegen, 
*uda_fn);
+    *uda_fn = codegen->FinalizeFunction(*uda_fn);
+    if (*uda_fn == NULL) {
+      return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location);
+    }
+  }
+  return Status::OK();
+}
+
+FunctionContext::TypeDesc AggFnEvaluator::GetIntermediateTypeDesc() const {
+  return AnyValUtil::ColumnTypeToTypeDesc(intermediate_slot_desc_->type());
+}
+
+FunctionContext::TypeDesc AggFnEvaluator::GetOutputTypeDesc() const {
+  return AnyValUtil::ColumnTypeToTypeDesc(output_slot_desc_->type());
+}
+
 string AggFnEvaluator::DebugString(const vector<AggFnEvaluator*>& exprs) {
   stringstream out;
   out << "[";

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/agg-fn-evaluator.h
----------------------------------------------------------------------
diff --git a/be/src/exprs/agg-fn-evaluator.h b/be/src/exprs/agg-fn-evaluator.h
index b3ecda0..712bf40 100644
--- a/be/src/exprs/agg-fn-evaluator.h
+++ b/be/src/exprs/agg-fn-evaluator.h
@@ -119,8 +119,6 @@ class AggFnEvaluator {
   bool SupportsRemove() const { return remove_fn_ != NULL; }
   bool SupportsSerialize() const { return serialize_fn_ != NULL; }
   const std::string& fn_name() const { return fn_.name.function_name; }
-  const std::string& update_symbol() const { return 
fn_.aggregate_fn.update_fn_symbol; }
-  const std::string& merge_symbol() const { return 
fn_.aggregate_fn.merge_fn_symbol; }
   const SlotDescriptor* output_slot_desc() const { return output_slot_desc_; }
 
   static std::string DebugString(const std::vector<AggFnEvaluator*>& exprs);
@@ -168,14 +166,8 @@ class AggFnEvaluator {
   static void Finalize(const std::vector<AggFnEvaluator*>& evaluators,
       const std::vector<FunctionContext*>& fn_ctxs, Tuple* src, Tuple* dst);
 
-  /// TODO: implement codegen path. These functions would return IR functions 
with
-  /// the same signature as the interpreted ones above.
-  /// Function* GetIrInitFn();
-  /// Function* GetIrAddFn();
-  /// Function* GetIrRemoveFn();
-  /// Function* GetIrSerializeFn();
-  /// Function* GetIrGetValueFn();
-  /// Function* GetIrFinalizeFn();
+  /// Gets the codegened update or merge function for this aggregate function.
+  Status GetUpdateOrMergeFunction(LlvmCodeGen* codegen, llvm::Function** 
uda_fn);
 
  private:
   const TFunction fn_;
@@ -195,6 +187,9 @@ class AggFnEvaluator {
   /// expression (e.g. count(*)).
   std::vector<ExprContext*> input_expr_ctxs_;
 
+  /// The types of the arguments to the aggregate function.
+  const std::vector<FunctionContext::TypeDesc> arg_type_descs_;
+
   /// The enum for some of the builtins that still require special cased logic.
   AggregationOp agg_op_;
 
@@ -221,6 +216,12 @@ class AggFnEvaluator {
   /// Use Create() instead.
   AggFnEvaluator(const TExprNode& desc, bool is_analytic_fn);
 
+  /// Return the intermediate type of the aggregate function.
+  FunctionContext::TypeDesc GetIntermediateTypeDesc() const;
+
+  /// Return the output type of the aggregate function.
+  FunctionContext::TypeDesc GetOutputTypeDesc() const;
+
   /// TODO: these functions below are not extensible and we need to use 
codegen to
   /// generate the calls into the UDA functions (like for UDFs).
   /// Remove these functions when this is supported.

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/anyval-util.cc
----------------------------------------------------------------------
diff --git a/be/src/exprs/anyval-util.cc b/be/src/exprs/anyval-util.cc
index 132d6e4..c49cdb3 100644
--- a/be/src/exprs/anyval-util.cc
+++ b/be/src/exprs/anyval-util.cc
@@ -92,17 +92,33 @@ FunctionContext::TypeDesc 
AnyValUtil::ColumnTypeToTypeDesc(const ColumnType& typ
   return out;
 }
 
+vector<FunctionContext::TypeDesc> AnyValUtil::ColumnTypesToTypeDescs(
+    const vector<ColumnType>& types) {
+  vector<FunctionContext::TypeDesc> type_descs;
+  for (const ColumnType& type : types) 
type_descs.push_back(ColumnTypeToTypeDesc(type));
+  return type_descs;
+}
+
 ColumnType AnyValUtil::TypeDescToColumnType(const FunctionContext::TypeDesc& 
type) {
   switch (type.type) {
-    case FunctionContext::TYPE_BOOLEAN: return ColumnType(TYPE_BOOLEAN);
-    case FunctionContext::TYPE_TINYINT: return ColumnType(TYPE_TINYINT);
-    case FunctionContext::TYPE_SMALLINT: return ColumnType(TYPE_SMALLINT);
-    case FunctionContext::TYPE_INT: return ColumnType(TYPE_INT);
-    case FunctionContext::TYPE_BIGINT: return ColumnType(TYPE_BIGINT);
-    case FunctionContext::TYPE_FLOAT: return ColumnType(TYPE_FLOAT);
-    case FunctionContext::TYPE_DOUBLE: return ColumnType(TYPE_DOUBLE);
-    case FunctionContext::TYPE_TIMESTAMP: return ColumnType(TYPE_TIMESTAMP);
-    case FunctionContext::TYPE_STRING: return ColumnType(TYPE_STRING);
+    case FunctionContext::TYPE_BOOLEAN:
+      return ColumnType(TYPE_BOOLEAN);
+    case FunctionContext::TYPE_TINYINT:
+      return ColumnType(TYPE_TINYINT);
+    case FunctionContext::TYPE_SMALLINT:
+      return ColumnType(TYPE_SMALLINT);
+    case FunctionContext::TYPE_INT:
+      return ColumnType(TYPE_INT);
+    case FunctionContext::TYPE_BIGINT:
+      return ColumnType(TYPE_BIGINT);
+    case FunctionContext::TYPE_FLOAT:
+      return ColumnType(TYPE_FLOAT);
+    case FunctionContext::TYPE_DOUBLE:
+      return ColumnType(TYPE_DOUBLE);
+    case FunctionContext::TYPE_TIMESTAMP:
+      return ColumnType(TYPE_TIMESTAMP);
+    case FunctionContext::TYPE_STRING:
+      return ColumnType(TYPE_STRING);
     case FunctionContext::TYPE_DECIMAL:
       return ColumnType::CreateDecimalType(type.precision, type.scale);
     case FunctionContext::TYPE_FIXED_BUFFER:

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/anyval-util.h
----------------------------------------------------------------------
diff --git a/be/src/exprs/anyval-util.h b/be/src/exprs/anyval-util.h
index e5473c7..429322a 100644
--- a/be/src/exprs/anyval-util.h
+++ b/be/src/exprs/anyval-util.h
@@ -227,6 +227,8 @@ class AnyValUtil {
   }
 
   static FunctionContext::TypeDesc ColumnTypeToTypeDesc(const ColumnType& 
type);
+  static std::vector<FunctionContext::TypeDesc> ColumnTypesToTypeDescs(
+      const std::vector<ColumnType>& types);
   // Note: constructing a ColumnType is expensive and should be avoided in 
query execution
   // paths (i.e. non-setup paths).
   static ColumnType TypeDescToColumnType(const FunctionContext::TypeDesc& 
type);

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/expr.cc
----------------------------------------------------------------------
diff --git a/be/src/exprs/expr.cc b/be/src/exprs/expr.cc
index 2ffddf6..f119a9d 100644
--- a/be/src/exprs/expr.cc
+++ b/be/src/exprs/expr.cc
@@ -637,10 +637,10 @@ int Expr::InlineConstants(LlvmCodeGen* codegen, Function* 
fn) {
 }
 
 int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type,
-      const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* 
codegen,
-      Function* fn) {
+    const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* 
codegen,
+    Function* fn) {
   int replaced = 0;
-  for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end; ) 
{
+  for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end;) {
     // Increment iter now so we don't mess it up modifying the instruction 
below
     Instruction* instr = &*(iter++);
 
@@ -666,7 +666,7 @@ int Expr::InlineConstants(const FunctionContext::TypeDesc& 
return_type,
     int i_val = static_cast<int>(i_arg->getSExtValue());
     // All supported constants are currently integers.
     call_instr->replaceAllUsesWith(ConstantInt::get(codegen->GetType(TYPE_INT),
-          GetConstantInt(return_type, arg_types, c_val, i_val)));
+        GetConstantInt(return_type, arg_types, c_val, i_val)));
     call_instr->eraseFromParent();
     ++replaced;
   }

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/expr.h
----------------------------------------------------------------------
diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h
index b77c9a2..13ba312 100644
--- a/be/src/exprs/expr.h
+++ b/be/src/exprs/expr.h
@@ -263,9 +263,11 @@ class Expr {
   // Any additions to this enum must be reflected in both GetConstant*() and
   // GetIrConstant().
   enum ExprConstant {
+    // RETURN_TYPE_*: properties of FunctionContext::GetReturnType().
     RETURN_TYPE_SIZE, // int
     RETURN_TYPE_PRECISION, // int
     RETURN_TYPE_SCALE, // int
+    // ARG_TYPE_* with parameter i: properties of 
FunctionContext::GetArgType(i).
     ARG_TYPE_SIZE, // int[]
     ARG_TYPE_PRECISION, // int[]
     ARG_TYPE_SCALE, // int[]
@@ -289,7 +291,8 @@ class Expr {
   // constants to be replaced must be inlined into the function that 
InlineConstants()
   // is run on (e.g. by annotating them with IR_ALWAYS_INLINE).
   //
-  // TODO: implement a loop unroller (or use LLVM's) so we can use 
GetConstantInt() in loops
+  // TODO: implement a loop unroller (or use LLVM's) so we can use 
GetConstantInt() in
+  // loops
   static int GetConstantInt(const FunctionContext& ctx, ExprConstant c, int i 
= -1);
 
   /// Finds all calls to Expr::GetConstantInt() in 'fn' and replaces them with 
the
@@ -298,8 +301,8 @@ class Expr {
   /// 'arg_types' are the argument types of the UDF or UDAF, i.e. the values of
   /// FunctionContext::GetArgType().
   static int InlineConstants(const FunctionContext::TypeDesc& return_type,
-      const std::vector<FunctionContext::TypeDesc>& arg_types,
-      LlvmCodeGen* codegen, llvm::Function* fn);
+      const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* 
codegen,
+      llvm::Function* fn);
 
   static const char* LLVM_CLASS_NAME;
 

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/scalar-fn-call.cc
----------------------------------------------------------------------
diff --git a/be/src/exprs/scalar-fn-call.cc b/be/src/exprs/scalar-fn-call.cc
index c7d4e7e..06830b7 100644
--- a/be/src/exprs/scalar-fn-call.cc
+++ b/be/src/exprs/scalar-fn-call.cc
@@ -19,12 +19,12 @@
 
 #include <vector>
 #include <gutil/strings/substitute.h>
-#include <llvm/IR/Attributes.h>
 #include <llvm/ExecutionEngine/ExecutionEngine.h>
+#include <llvm/IR/Attributes.h>
 
 #include <boost/preprocessor/punctuation/comma_if.hpp>
-#include <boost/preprocessor/repetition/repeat.hpp>
 #include <boost/preprocessor/repetition/enum_params.hpp>
+#include <boost/preprocessor/repetition/repeat.hpp>
 #include <boost/preprocessor/repetition/repeat_from_to.hpp>
 
 #include "codegen/codegen-anyval.h"
@@ -37,8 +37,6 @@
 #include "runtime/types.h"
 #include "udf/udf-internal.h"
 #include "util/debug-util.h"
-#include "util/dynamic-util.h"
-#include "util/symbols-util.h"
 
 #include "common/names.h"
 
@@ -311,21 +309,28 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
     }
   }
 
+  vector<ColumnType> arg_types;
+  for (const Expr* child : children_) arg_types.push_back(child->type());
+  Function* udf;
+  RETURN_IF_ERROR(codegen->LoadFunction(fn_, fn_.scalar_fn.symbol, &type_, arg_types,
+      NumFixedArgs(), vararg_start_idx_ != -1, &udf, &cache_entry_));
+  // Inline constants into the function if it has an IR body.
+  if (!udf->isDeclaration()) {
+    InlineConstants(AnyValUtil::ColumnTypeToTypeDesc(type_),
+        AnyValUtil::ColumnTypesToTypeDescs(arg_types), codegen, udf);
+    udf = codegen->FinalizeFunction(udf);
+    if (udf == NULL) {
+      return Status(
+          TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location);
+    }
+  }
+
   if (fn_.binary_type == TFunctionBinaryType::IR) {
-    string local_path;
-    RETURN_IF_ERROR(LibCache::instance()->GetLocalLibPath(
-        fn_.hdfs_location, LibCache::TYPE_IR, &local_path));
-    // Link the UDF module into this query's main module (essentially copy the UDF
-    // module into the main module) so the UDF's functions are available in the main
-    // module.
-    RETURN_IF_ERROR(codegen->LinkModule(local_path));
-    // Load the Prepare() and Close() functions from the LLVM module.
+    // LoadFunction() should have linked the IR module into 'codegen'. Now load the
+    // Prepare() and Close() functions from 'codegen'.
     RETURN_IF_ERROR(LoadPrepareAndCloseFn(codegen));
   }
 
-  Function* udf;
-  RETURN_IF_ERROR(GetUdf(codegen, &udf));
-
   // Create wrapper that computes args and calls UDF
   stringstream fn_name;
   fn_name << udf->getName().str() << "Wrapper";
@@ -407,8 +412,7 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
     // Add the number of varargs
     udf_args.push_back(codegen->GetIntConstant(TYPE_INT, NumVarArgs()));
     // Add all the accumulated vararg inputs as one input argument.
-    PointerType* vararg_type =
-        codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, VarArgsType()));
+    PointerType* vararg_type = CodegenAnyVal::GetUnloweredPtrType(codegen, VarArgsType());
     udf_args.push_back(builder.CreateBitCast(varargs_buffer, vararg_type, "varargs"));
   }
 
@@ -428,119 +432,11 @@ Status ScalarFnCall::GetCodegendComputeFn(LlvmCodeGen* codegen, Function** fn) {
   return Status::OK();
 }
 
-Status ScalarFnCall::GetUdf(LlvmCodeGen* codegen, Function** udf) {
-  // from_utc_timestamp() and to_utc_timestamp() have inline ASM that cannot be JIT'd.
-  // TimestampFunctions::AddSub() contains a try/catch which doesn't work in JIT'd
-  // code. Always use the interpreted version of these functions.
-  // TODO: fix these built-in functions so we don't need 'broken_builtin' below.
-  bool broken_builtin = fn_.name.function_name == "from_utc_timestamp" ||
-                        fn_.name.function_name == "to_utc_timestamp" ||
-                        fn_.scalar_fn.symbol.find("AddSub") != string::npos;
-  if (fn_.binary_type == TFunctionBinaryType::NATIVE ||
-      (fn_.binary_type == TFunctionBinaryType::BUILTIN && broken_builtin)) {
-    // In this path, we are code that has been statically compiled to assembly.
-    // This can either be a UDF implemented in a .so or a builtin using the UDF
-    // interface with the code in impalad.
-    void* fn_ptr;
-    Status status = LibCache::instance()->GetSoFunctionPtr(
-        fn_.hdfs_location, fn_.scalar_fn.symbol, &fn_ptr, &cache_entry_);
-    if (!status.ok() && fn_.binary_type == TFunctionBinaryType::BUILTIN) {
-      // Builtins symbols should exist unless there is a version mismatch.
-      status.AddDetail(ErrorMsg(TErrorCode::MISSING_BUILTIN,
-          fn_.name.function_name, fn_.scalar_fn.symbol).msg());
-    }
-    RETURN_IF_ERROR(status);
-    DCHECK(fn_ptr != NULL);
-
-    // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument.
-    // So, the return type is void.
-    bool is_decimal = type().type == TYPE_DECIMAL;
-    Type* return_type = is_decimal ? codegen->void_type() :
-                                     CodegenAnyVal::GetLoweredType(codegen, type());
-
-    // Convert UDF function pointer to Function*. Start by creating a function
-    // prototype for it.
-    LlvmCodeGen::FnPrototype prototype(codegen, fn_.scalar_fn.symbol, return_type);
-
-    if (is_decimal) {
-      // Per the x64 ABI, DecimalVals are returned via a DecmialVal* output argument
-      Type* output_type =
-          codegen->GetPtrType(CodegenAnyVal::GetUnloweredType(codegen, type()));
-      prototype.AddArgument("output", output_type);
-    }
-
-    // The "FunctionContext*" argument.
-    prototype.AddArgument("ctx",
-        codegen->GetPtrType("class.impala_udf::FunctionContext"));
-
-    // The "fixed" arguments for the UDF function.
-    for (int i = 0; i < NumFixedArgs(); ++i) {
-      stringstream arg_name;
-      arg_name << "fixed_arg_" << i;
-      Type* arg_type = codegen->GetPtrType(
-          CodegenAnyVal::GetUnloweredType(codegen, children_[i]->type()));
-      prototype.AddArgument(arg_name.str(), arg_type);
-    }
-    // The varargs for the UDF function if there is any.
-    if (NumVarArgs() > 0) {
-      Type* vararg_type = CodegenAnyVal::GetUnloweredPtrType(
-          codegen, children_[vararg_start_idx_]->type());
-      prototype.AddArgument("num_var_arg", codegen->GetType(TYPE_INT));
-      prototype.AddArgument("var_arg", vararg_type);
-    }
-
-    // Create a Function* with the generated type. This is only a function
-    // declaration, not a definition, since we do not create any basic blocks or
-    // instructions in it.
-    *udf = prototype.GeneratePrototype(NULL, NULL, false);
-
-    // Associate the dynamically loaded function pointer with the Function* we defined.
-    // This tells LLVM where the compiled function definition is located in memory.
-    codegen->execution_engine()->addGlobalMapping(*udf, fn_ptr);
-  } else if (fn_.binary_type == TFunctionBinaryType::BUILTIN) {
-    // In this path, we're running a builtin with the UDF interface. The IR is
-    // in the llvm module.
-    *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false);
-    if (*udf == NULL) {
-      // Builtins symbols should exist unless there is a version mismatch.
-      stringstream ss;
-      ss << "Builtin '" << fn_.name.function_name << "' with symbol '"
-         << fn_.scalar_fn.symbol << "' does not exist. "
-         << "Verify that all your impalads are the same version.";
-      return Status(ss.str());
-    }
-    // Builtin functions may use Expr::GetConstant(). Clone the function in case we need
-    // to use it again, and rename it to something more manageable than the mangled name.
-    string demangled_name = SymbolsUtil::DemangleNoArgs((*udf)->getName().str());
-    *udf = codegen->CloneFunction(*udf);
-    (*udf)->setName(demangled_name);
-    InlineConstants(codegen, *udf);
-    *udf = codegen->FinalizeFunction(*udf);
-    DCHECK(*udf != NULL);
-  } else {
-    // We're running an IR UDF.
-    DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR);
-    *udf = codegen->GetFunction(fn_.scalar_fn.symbol, false);
-    if (*udf == NULL) {
-      stringstream ss;
-      ss << "Unable to locate function " << fn_.scalar_fn.symbol << " from LLVM module "
-         << fn_.hdfs_location;
-      return Status(ss.str());
-    }
-    *udf = codegen->FinalizeFunction(*udf);
-    if (*udf == NULL) {
-      return Status(
-          TErrorCode::UDF_VERIFY_FAILED, fn_.scalar_fn.symbol, fn_.hdfs_location);
-    }
-  }
-  return Status::OK();
-}
-
 Status ScalarFnCall::GetFunction(LlvmCodeGen* codegen, const string& symbol, void** fn) {
-  if (fn_.binary_type == TFunctionBinaryType::NATIVE ||
-      fn_.binary_type == TFunctionBinaryType::BUILTIN) {
-    return LibCache::instance()->GetSoFunctionPtr(fn_.hdfs_location, symbol, fn,
-        &cache_entry_);
+  if (fn_.binary_type == TFunctionBinaryType::NATIVE
+      || fn_.binary_type == TFunctionBinaryType::BUILTIN) {
+    return LibCache::instance()->GetSoFunctionPtr(
+        fn_.hdfs_location, symbol, fn, &cache_entry_);
   } else {
     DCHECK_EQ(fn_.binary_type, TFunctionBinaryType::IR);
     DCHECK(codegen != NULL);

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/scalar-fn-call.h
----------------------------------------------------------------------
diff --git a/be/src/exprs/scalar-fn-call.h b/be/src/exprs/scalar-fn-call.h
index c8bc8c8..ffb9a2f 100644
--- a/be/src/exprs/scalar-fn-call.h
+++ b/be/src/exprs/scalar-fn-call.h
@@ -48,7 +48,7 @@ class TExprNode;
 ///    - Test cancellation
 ///    - Type descs in UDA test harness
 ///    - Allow more functions to be NULL in UDA test harness
-class ScalarFnCall: public Expr {
+class ScalarFnCall : public Expr {
  public:
   virtual std::string DebugString() const;
 
@@ -117,9 +117,6 @@ class ScalarFnCall: public Expr {
     return children_.back()->type();
   }
 
-  /// Loads the native or IR function from HDFS and puts the result in *udf.
-  Status GetUdf(LlvmCodeGen* codegen, llvm::Function** udf);
-
   /// Loads the native or IR function 'symbol' from HDFS and puts the result in *fn.
   /// If the function is loaded from an IR module, it cannot be called until the module
   /// has been JIT'd (i.e. after GetCodegendComputeFn() has been called).

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/exprs/timestamp-functions.cc
----------------------------------------------------------------------
diff --git a/be/src/exprs/timestamp-functions.cc b/be/src/exprs/timestamp-functions.cc
index a519696..a63438c 100644
--- a/be/src/exprs/timestamp-functions.cc
+++ b/be/src/exprs/timestamp-functions.cc
@@ -76,7 +76,7 @@ void ThrowIfDateOutOfRange(const boost::gregorian::date& date) {
 
 // This function uses inline asm functions, which we believe to be from the boost library.
 // Inline asm is not currently supported by JIT, so this function should always be run in
-// the interpreted mode. This is handled in ScalarFnCall::GetUdf().
+// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction().
 TimestampVal TimestampFunctions::FromUtc(FunctionContext* context,
     const TimestampVal& ts_val, const StringVal& tz_string_val) {
   if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null();
@@ -114,7 +114,7 @@ TimestampVal TimestampFunctions::FromUtc(FunctionContext* context,
 
 // This function uses inline asm functions, which we believe to be from the boost library.
 // Inline asm is not currently supported by JIT, so this function should always be run in
-// the interpreted mode. This is handled in ScalarFnCall::GetUdf().
+// the interpreted mode. This is handled in LlvmCodeGen::LoadFunction().
 TimestampVal TimestampFunctions::ToUtc(FunctionContext* context,
     const TimestampVal& ts_val, const StringVal& tz_string_val) {
   if (ts_val.is_null || tz_string_val.is_null) return TimestampVal::null();

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/runtime/types.cc
----------------------------------------------------------------------
diff --git a/be/src/runtime/types.cc b/be/src/runtime/types.cc
index 3a04ca3..f580628 100644
--- a/be/src/runtime/types.cc
+++ b/be/src/runtime/types.cc
@@ -310,6 +310,12 @@ string ColumnType::DebugString() const {
   }
 }
 
+vector<ColumnType> ColumnType::FromThrift(const vector<TColumnType>& ttypes) {
+  vector<ColumnType> types;
+  for (const TColumnType& ttype : ttypes) types.push_back(FromThrift(ttype));
+  return types;
+}
+
 ostream& operator<<(ostream& os, const ColumnType& type) {
   os << type.DebugString();
   return os;

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/runtime/types.h
----------------------------------------------------------------------
diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h
index f265705..f2db3bd 100644
--- a/be/src/runtime/types.h
+++ b/be/src/runtime/types.h
@@ -147,6 +147,8 @@ struct ColumnType {
     return result;
   }
 
+  static std::vector<ColumnType> FromThrift(const std::vector<TColumnType>& ttypes);
+
   bool operator==(const ColumnType& o) const {
     if (type != o.type) return false;
     if (children != o.children) return false;

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/testutil/test-udas.cc
----------------------------------------------------------------------
diff --git a/be/src/testutil/test-udas.cc b/be/src/testutil/test-udas.cc
index 0097500..549f2f0 100644
--- a/be/src/testutil/test-udas.cc
+++ b/be/src/testutil/test-udas.cc
@@ -17,6 +17,9 @@
 
 #include "testutil/test-udas.h"
 
+#include <assert.h>
+
+// Don't include Impala internal headers - real UDAs won't include them.
 #include <udf/udf.h>
 
 using namespace impala_udf;
@@ -48,19 +51,95 @@ void Agg(FunctionContext*, const StringVal&, const DoubleVal&, StringVal*) {}
 void AggInit(FunctionContext*, StringVal*){}
 void AggMerge(FunctionContext*, const StringVal&, StringVal*) {}
 StringVal AggSerialize(FunctionContext*, const StringVal& v) { return v;}
-StringVal AggFinalize(FunctionContext*, const StringVal& v) { return v;}
-
+StringVal AggFinalize(FunctionContext*, const StringVal& v) {
+  return v;
+}
 
-// Defines AggIntermediate(int) returns BIGINT intermediate CHAR(10)
-// TODO: StringVal should be replaced with BufferVal in Impala 2.0
-void AggIntermediate(FunctionContext*, const IntVal&, StringVal*) {}
-void AggIntermediateUpdate(FunctionContext*, const IntVal&, StringVal*) {}
-void AggIntermediateInit(FunctionContext*, StringVal*) {}
-void AggIntermediateMerge(FunctionContext*, const StringVal&, StringVal*) {}
-BigIntVal AggIntermediateFinalize(FunctionContext*, const StringVal&) {
+// Defines AggIntermediate(int) returns BIGINT intermediate STRING
+void AggIntermediate(FunctionContext* context, const IntVal&, StringVal*) {}
+void AggIntermediateUpdate(FunctionContext* context, const IntVal&, StringVal*) {
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+}
+void AggIntermediateInit(FunctionContext* context, StringVal*) {
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+}
+void AggIntermediateMerge(FunctionContext* context, const StringVal&, StringVal*) {
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
+}
+BigIntVal AggIntermediateFinalize(FunctionContext* context, const StringVal&) {
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_STRING);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   return BigIntVal::null();
 }
 
+// Defines AggDecimalIntermediate(DECIMAL(2,1), INT) returns DECIMAL(6,5)
+// intermediate DECIMAL(4,3)
+// Useful to test that type parameters are plumbed through.
+void AggDecimalIntermediateUpdate(FunctionContext* context, const DecimalVal&, const IntVal&, DecimalVal*) {
+  assert(context->GetNumArgs() == 2);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetArgType(0)->precision == 2);
+  assert(context->GetArgType(0)->scale == 1);
+  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetIntermediateType().precision == 4);
+  assert(context->GetIntermediateType().scale == 3);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetReturnType().precision == 6);
+  assert(context->GetReturnType().scale == 5);
+}
+void AggDecimalIntermediateInit(FunctionContext* context, DecimalVal*) {
+  assert(context->GetNumArgs() == 2);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetArgType(0)->precision == 2);
+  assert(context->GetArgType(0)->scale == 1);
+  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetIntermediateType().precision == 4);
+  assert(context->GetIntermediateType().scale == 3);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetReturnType().precision == 6);
+  assert(context->GetReturnType().scale == 5);
+}
+void AggDecimalIntermediateMerge(FunctionContext* context, const DecimalVal&, DecimalVal*) {
+  assert(context->GetNumArgs() == 2);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetArgType(0)->precision == 2);
+  assert(context->GetArgType(0)->scale == 1);
+  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetIntermediateType().precision == 4);
+  assert(context->GetIntermediateType().scale == 3);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetReturnType().precision == 6);
+  assert(context->GetReturnType().scale == 5);
+}
+DecimalVal AggDecimalIntermediateFinalize(FunctionContext* context, const DecimalVal&) {
+  assert(context->GetNumArgs() == 2);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetArgType(0)->precision == 2);
+  assert(context->GetArgType(0)->scale == 1);
+  assert(context->GetArgType(1)->type == FunctionContext::TYPE_INT);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetIntermediateType().precision == 4);
+  assert(context->GetIntermediateType().scale == 3);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_DECIMAL);
+  assert(context->GetReturnType().precision == 6);
+  assert(context->GetReturnType().scale == 5);
+  return DecimalVal::null();
+}
+
 // Defines MemTest(bigint) return bigint
 // "Allocates" the specified number of bytes in the update function and frees them in the
 // serialize function. Useful for testing mem limits.
@@ -99,22 +178,57 @@ BigIntVal MemTestFinalize(FunctionContext* context, const BigIntVal& total) {
 // Defines aggregate function for testing different intermediate/output types that
 // computes the truncated bigint sum of many floats.
 void TruncSumInit(FunctionContext* context, DoubleVal* total) {
+  // Arg types should be logical input types of UDA.
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetArgType(-1) == nullptr);
+  assert(context->GetArgType(1) == nullptr);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   *total = DoubleVal(0);
 }
 
 void TruncSumUpdate(FunctionContext* context, const DoubleVal& val, DoubleVal* total) {
+  // Arg types should be logical input types of UDA.
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetArgType(-1) == nullptr);
+  assert(context->GetArgType(1) == nullptr);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   total->val += val.val;
 }
 
 void TruncSumMerge(FunctionContext* context, const DoubleVal& src, DoubleVal* dst) {
+  // Arg types should be logical input types of UDA.
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetArgType(-1) == nullptr);
+  assert(context->GetArgType(1) == nullptr);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   dst->val += src.val;
 }
 
 const DoubleVal TruncSumSerialize(FunctionContext* context, const DoubleVal& total) {
+  // Arg types should be logical input types of UDA.
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetArgType(-1) == nullptr);
+  assert(context->GetArgType(1) == nullptr);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   return total;
 }
 
 BigIntVal TruncSumFinalize(FunctionContext* context, const DoubleVal& total) {
+  // Arg types should be logical input types of UDA.
+  assert(context->GetNumArgs() == 1);
+  assert(context->GetArgType(0)->type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetArgType(-1) == nullptr);
+  assert(context->GetArgType(1) == nullptr);
+  assert(context->GetIntermediateType().type == FunctionContext::TYPE_DOUBLE);
+  assert(context->GetReturnType().type == FunctionContext::TYPE_BIGINT);
   return BigIntVal(static_cast<int64_t>(total.val));
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/testutil/test-udas.h
----------------------------------------------------------------------
diff --git a/be/src/testutil/test-udas.h b/be/src/testutil/test-udas.h
index 8cd38ec..57b0f06 100644
--- a/be/src/testutil/test-udas.h
+++ b/be/src/testutil/test-udas.h
@@ -18,6 +18,7 @@
 #ifndef IMPALA_UDF_TEST_UDAS_H
 #define IMPALA_UDF_TEST_UDAS_H
 
+// Don't include Impala internal headers - real UDAs won't include them.
 #include "udf/udf.h"
 
 using namespace impala_udf;

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf-internal.h
----------------------------------------------------------------------
diff --git a/be/src/udf/udf-internal.h b/be/src/udf/udf-internal.h
index bf3032e..96d0fc7 100644
--- a/be/src/udf/udf-internal.h
+++ b/be/src/udf/udf-internal.h
@@ -47,6 +47,12 @@ class RuntimeState;
 /// This class actually implements the interface of FunctionContext. This is split to
 /// hide the details from the external header.
 /// Note: The actual user code does not include this file.
+///
+/// Exprs (e.g. UDFs and UDAs) require a FunctionContext to store state related to
+/// evaluation of the expression. Each FunctionContext is associated with a backend Expr
+/// or AggFnEvaluator, which is derived from a TExprNode generated by the Impala frontend.
+/// FunctionContexts are allocated and managed by ExprContext. Exprs shouldn't try to
+/// create FunctionContexts themselves.
 class FunctionContextImpl {
  public:
   /// Create a FunctionContext for a UDF. Caller is responsible for deleting it.

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf-ir.cc
----------------------------------------------------------------------
diff --git a/be/src/udf/udf-ir.cc b/be/src/udf/udf-ir.cc
index 24773f0..c12133c 100644
--- a/be/src/udf/udf-ir.cc
+++ b/be/src/udf/udf-ir.cc
@@ -34,6 +34,10 @@ int FunctionContext::GetNumArgs() const {
   return impl_->arg_types_.size();
 }
 
+const FunctionContext::TypeDesc& FunctionContext::GetIntermediateType() const {
+  return impl_->intermediate_type_;
+}
+
 const FunctionContext::TypeDesc& FunctionContext::GetReturnType() const {
   return impl_->return_type_;
 }

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/be/src/udf/udf.h
----------------------------------------------------------------------
diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h
index 461482c..4fdca2d 100644
--- a/be/src/udf/udf.h
+++ b/be/src/udf/udf.h
@@ -189,21 +189,25 @@ class FunctionContext {
   const TypeDesc& GetIntermediateType() const;
 
   /// Returns the number of arguments to this function (not including the FunctionContext*
-  /// argument).
+  /// argument or the output of a UDA).
+  /// For UDAs, returns the number of logical arguments of the aggregate function, not
+  /// the number of arguments of the C++ function being executed.
   int GetNumArgs() const;
 
   /// Returns the type information for the arg_idx-th argument (0-indexed, not including
   /// the FunctionContext* argument). Returns NULL if arg_idx is invalid.
+  /// For UDAs, returns the logical argument types of the aggregate function, not the
+  /// argument types of the C++ function being executed.
   const TypeDesc* GetArgType(int arg_idx) const;
 
-  /// Returns true if the arg_idx-th input argument (0 indexed, not including the
-  /// FunctionContext* argument) is a constant (e.g. 5, "string", 1 + 1).
+  /// Returns true if the arg_idx-th input argument (indexed in the same way as
+  /// GetArgType()) is a constant (e.g. 5, "string", 1 + 1).
   bool IsArgConstant(int arg_idx) const;
 
-  /// Returns a pointer to the value of the arg_idx-th input argument (0 indexed, not
-  /// including the FunctionContext* argument). Returns NULL if the argument is not
-  /// constant. This function can be used to obtain user-specified constants in a UDF's
-  /// Init() or Close() functions.
+  /// Returns a pointer to the value of the arg_idx-th input argument (indexed in the
+  /// same way as GetArgType()). Returns NULL if the argument is not constant. This
+  /// function can be used to obtain user-specified constants in a UDF's Init() or
+  /// Close() functions.
   AnyVal* GetConstantArg(int arg_idx) const;
 
   /// TODO: Do we need to add arbitrary key/value metadata. This would be plumbed

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/common/thrift/Exprs.thrift
----------------------------------------------------------------------
diff --git a/common/thrift/Exprs.thrift b/common/thrift/Exprs.thrift
index 33b859a..fc0f4ee 100644
--- a/common/thrift/Exprs.thrift
+++ b/common/thrift/Exprs.thrift
@@ -112,9 +112,14 @@ struct TStringLiteral {
   1: required string value;
 }
 
+// Additional information for aggregate functions.
 struct TAggregateExpr {
   // Indicates whether this expr is the merge() of an aggregation.
   1: required bool is_merge_agg
+
+  // The types of the input arguments to the aggregate function. May differ from the
+  // input expr types if this is the merge() of an aggregation.
+  2: required list<Types.TColumnType> arg_types;
 }
 
 // This is essentially a union over the subclasses of Expr.

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java
index ec88aae..5dbde32 100644
--- a/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java
+++ b/fe/src/main/java/org/apache/impala/analysis/AggregateInfo.java
@@ -673,7 +673,8 @@ public class AggregateInfo extends AggregateInfoBase {
    * materialized slots of the output tuple corresponds to the number of materialized
    * aggregate functions plus the number of grouping exprs. Also checks that the return
    * types of the aggregate and grouping exprs correspond to the slots in the output
-   * tuple.
+   * tuple and that the input types stored in the merge aggregation are consistent
+   * with the input exprs.
    */
   public void checkConsistency() {
     ArrayList<SlotDescriptor> slots = outputTupleDesc_.getSlots();
@@ -707,6 +708,13 @@ public class AggregateInfo extends AggregateInfoBase {
               slotType.toString()));
       ++slotIdx;
     }
+    if (mergeAggInfo_ != null) {
+      // Check that the argument types in mergeAggInfo_ are consistent with input exprs.
+      for (int i = 0; i < aggregateExprs_.size(); ++i) {
+        FunctionCallExpr mergeAggExpr = mergeAggInfo_.aggregateExprs_.get(i);
+        mergeAggExpr.validateMergeAggFn(aggregateExprs_.get(i));
+      }
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
----------------------------------------------------------------------
diff --git a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
index 15af25f..4d7dca8 100644
--- a/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
+++ b/fe/src/main/java/org/apache/impala/analysis/FunctionCallExpr.java
@@ -30,6 +30,7 @@ import org.apache.impala.catalog.Type;
 import org.apache.impala.common.AnalysisException;
 import org.apache.impala.common.TreeNode;
 import org.apache.impala.thrift.TAggregateExpr;
+import org.apache.impala.thrift.TColumnType;
 import org.apache.impala.thrift.TExprNode;
 import org.apache.impala.thrift.TExprNodeType;
 import org.apache.impala.thrift.TFunctionBinaryType;
@@ -45,10 +46,12 @@ public class FunctionCallExpr extends Expr {
   private boolean isAnalyticFnCall_ = false;
   private boolean isInternalFnCall_ = false;
 
-  // Indicates whether this is a merge aggregation function that should use the merge
-  // instead of the update symbol. This flag also affects the behavior of
-  // resetAnalysisState() which is used during expr substitution.
-  private final boolean isMergeAggFn_;
+  // Non-null iff this is an aggregation function that executes the Merge() step. This
+  // is an analyzed clone of the FunctionCallExpr that executes the Update() function
+  // feeding into this Merge(). This is stored so that we can access the types of the
+  // original input argument exprs. Note that the nullness affects the behaviour of
+  // resetAnalysisState(), which is used during expr substitution.
+  private final FunctionCallExpr mergeAggInputFn_;
 
   // Printed in toSqlImpl(), if set. Used for merge agg fns.
   private String label_;
@@ -62,15 +65,16 @@ public class FunctionCallExpr extends Expr {
   }
 
   public FunctionCallExpr(FunctionName fnName, FunctionParams params) {
-    this(fnName, params, false);
+    this(fnName, params, null);
   }
 
-  private FunctionCallExpr(
-      FunctionName fnName, FunctionParams params, boolean isMergeAggFn) {
+  private FunctionCallExpr(FunctionName fnName, FunctionParams params,
+      FunctionCallExpr mergeAggInputFn) {
     super();
     fnName_ = fnName;
     params_ = params;
-    isMergeAggFn_ = isMergeAggFn;
+    mergeAggInputFn_ =
+        mergeAggInputFn == null ? null : (FunctionCallExpr)mergeAggInputFn.clone();
     if (params.exprs() != null) children_ = Lists.newArrayList(params_.exprs());
   }
 
@@ -99,12 +103,12 @@ public class FunctionCallExpr extends Expr {
     Preconditions.checkState(agg.isAnalyzed());
     Preconditions.checkState(agg.isAggregateFunction());
     FunctionCallExpr result = new FunctionCallExpr(
-        agg.fnName_, new FunctionParams(false, params), true);
+        agg.fnName_, new FunctionParams(false, params), agg);
     // Inherit the function object from 'agg'.
     result.fn_ = agg.fn_;
     result.type_ = agg.type_;
     // Set an explicit label based on the input agg.
-    if (agg.isMergeAggFn_) {
+    if (agg.isMergeAggFn()) {
       result.label_ = agg.label_;
     } else {
       // fn(input) becomes fn:merge(input).
@@ -123,7 +127,8 @@ public class FunctionCallExpr extends Expr {
     fnName_ = other.fnName_;
     isAnalyticFnCall_ = other.isAnalyticFnCall_;
     isInternalFnCall_ = other.isInternalFnCall_;
-    isMergeAggFn_ = other.isMergeAggFn_;
+    mergeAggInputFn_ =
+        other.mergeAggInputFn_ == null ? null : 
(FunctionCallExpr)other.mergeAggInputFn_.clone();
     // Clone the params in a way that keeps the children_ and the params.exprs()
     // in sync. The children have already been cloned in the super c'tor.
     if (other.params_.isStar()) {
@@ -135,7 +140,7 @@ public class FunctionCallExpr extends Expr {
     label_ = other.label_;
   }
 
-  public boolean isMergeAggFn() { return isMergeAggFn_; }
+  public boolean isMergeAggFn() { return mergeAggInputFn_ != null; }
 
   @Override
   public void resetAnalysisState() {
@@ -144,7 +149,7 @@ public class FunctionCallExpr extends Expr {
     // intermediate agg type is not the same as the output type. Preserve the original
     // fn_ such that analyze() hits the special-case code for merge agg fns that
     // handles this case.
-    if (!isMergeAggFn_) fn_ = null;
+    if (!isMergeAggFn()) fn_ = null;
   }
 
   @Override
@@ -160,7 +165,7 @@ public class FunctionCallExpr extends Expr {
   public String toSqlImpl() {
     if (label_ != null) return label_;
     // Merge agg fns should have an explicit label.
-    Preconditions.checkState(!isMergeAggFn_);
+    Preconditions.checkState(!isMergeAggFn());
     StringBuilder sb = new StringBuilder();
     sb.append(fnName_).append("(");
     if (params_.isStar()) sb.append("*");
@@ -226,7 +231,12 @@ public class FunctionCallExpr extends Expr {
   protected void toThrift(TExprNode msg) {
     if (isAggregateFunction() || isAnalyticFnCall_) {
       msg.node_type = TExprNodeType.AGGREGATE_EXPR;
-      if (!isAnalyticFnCall_) msg.setAgg_expr(new 
TAggregateExpr(isMergeAggFn_));
+      List<TColumnType> aggFnArgTypes = Lists.newArrayList();
+      FunctionCallExpr inputAggFn = isMergeAggFn() ? mergeAggInputFn_ : this;
+      for (Expr child: inputAggFn.children_) {
+        aggFnArgTypes.add(child.getType().toThrift());
+      }
+      msg.setAgg_expr(new TAggregateExpr(isMergeAggFn(), aggFnArgTypes));
     } else {
       msg.node_type = TExprNodeType.FUNCTION_CALL;
     }
@@ -383,7 +393,7 @@ public class FunctionCallExpr extends Expr {
   protected void analyzeImpl(Analyzer analyzer) throws AnalysisException {
     fnName_.analyze(analyzer);
 
-    if (isMergeAggFn_) {
+    if (isMergeAggFn()) {
       // This is the function call expr after splitting up to a merge aggregation.
       // The function has already been analyzed so just do the minimal sanity
       // check here.
@@ -524,6 +534,25 @@ public class FunctionCallExpr extends Expr {
     }
   }
 
+  /**
+   * Validate that the internal state, specifically types, is consistent between the
+   * Update() and Merge() aggregate functions.
+   */
+  void validateMergeAggFn(FunctionCallExpr inputAggFn) {
+    Preconditions.checkState(isMergeAggFn());
+    List<Expr> copiedInputExprs = mergeAggInputFn_.getChildren();
+    List<Expr> inputExprs = inputAggFn.getChildren();
+    Preconditions.checkState(copiedInputExprs.size() == inputExprs.size());
+    for (int i = 0; i < inputExprs.size(); ++i) {
+      Type copiedInputType = copiedInputExprs.get(i).getType();
+      Type inputType = inputExprs.get(i).getType();
+      Preconditions.checkState(copiedInputType.equals(inputType),
+          String.format("Copied expr %s arg type %s differs from input expr 
type %s " +
+            "in original expr %s", toSql(), copiedInputType.toSql(),
+            inputType.toSql(), inputAggFn.toSql()));
+    }
+  }
+
   @Override
   public Expr clone() { return new FunctionCallExpr(this); }
 }

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/testdata/workloads/functional-query/queries/QueryTest/uda.test
----------------------------------------------------------------------
diff --git a/testdata/workloads/functional-query/queries/QueryTest/uda.test b/testdata/workloads/functional-query/queries/QueryTest/uda.test
index 05e24e7..3a9bbbe 100644
--- a/testdata/workloads/functional-query/queries/QueryTest/uda.test
+++ b/testdata/workloads/functional-query/queries/QueryTest/uda.test
@@ -69,3 +69,22 @@ from functional.alltypesagg
 ---- TYPES
 bigint,bigint
 ====
+---- QUERY
+# Test that all types are exposed via the FunctionContext correctly.
+# This relies on asserts in the UDA function.
+select agg_intermediate(int_col), count(*)
+from functional.alltypesagg
+---- RESULTS
+NULL,11000
+---- TYPES
+bigint,bigint
+====
+---- QUERY
+# Test that all types are exposed via the FunctionContext correctly.
+# This relies on asserts in the UDA function.
+select agg_decimal_intermediate(cast(d1 as decimal(2,1)), 2), count(*)
+from functional.decimal_tbl
+---- RESULTS
+NULL,5
+---- TYPES
+decimal,bigint

http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/d2d3f4c1/tests/query_test/test_udfs.py
----------------------------------------------------------------------
diff --git a/tests/query_test/test_udfs.py b/tests/query_test/test_udfs.py
index 746cf88..56ce233 100644
--- a/tests/query_test/test_udfs.py
+++ b/tests/query_test/test_udfs.py
@@ -93,6 +93,16 @@ update_fn='ToggleNullUpdate' merge_fn='ToggleNullMerge';
 create aggregate function {database}.count_nulls(bigint)
 returns bigint location '{location}'
 update_fn='CountNullsUpdate' merge_fn='CountNullsMerge';
+
+create aggregate function {database}.agg_intermediate(int)
+returns bigint intermediate string location '{location}'
+init_fn='AggIntermediateInit' update_fn='AggIntermediateUpdate'
+merge_fn='AggIntermediateMerge' finalize_fn='AggIntermediateFinalize';
+
+create aggregate function {database}.agg_decimal_intermediate(decimal(2,1), 
int)
+returns decimal(6,5) intermediate decimal(4,3) location '{location}'
+init_fn='AggDecimalIntermediateInit' update_fn='AggDecimalIntermediateUpdate'
+merge_fn='AggDecimalIntermediateMerge' 
finalize_fn='AggDecimalIntermediateFinalize';
 """
 
   # Create test UDF functions in {database} from library {location}

Reply via email to