IMPALA-2020, IMPALA-4809: Codegen support for DECIMAL_V2

Currently, codegen supports converting type attributes (e.g. a decimal type's precision and scale, or a type's size) obtained via calls to FunctionContextImpl::GetConstFnAttr() (previously Expr::GetConstantInt()) in cross-compiled code into runtime constants. This change extends that support to the query option DECIMAL_V2.
To exercise this change, it also handles a subset of IMPALA-2020: casting between decimal values is updated to support rounding (instead of truncation) when decimal_v2 is true. This change also cleans up the existing code by moving the codegen logic in Expr::InlineConstants() to the codegen module (now LlvmCodeGen::InlineConstFnAttrs()) and the type-related logic in Expr::GetConstantInt() to FunctionContextImpl. Change-Id: I2434d240f65b81389b8a8ba027f980a0e1d1f981 Reviewed-on: http://gerrit.cloudera.org:8080/5950 Reviewed-by: Michael Ho <[email protected]> Tested-by: Impala Public Jenkins Project: http://git-wip-us.apache.org/repos/asf/incubator-impala/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-impala/commit/f982c3f7 Tree: http://git-wip-us.apache.org/repos/asf/incubator-impala/tree/f982c3f7 Diff: http://git-wip-us.apache.org/repos/asf/incubator-impala/diff/f982c3f7 Branch: refs/heads/master Commit: f982c3f76e762f646b52a01659796b4edcfbc1ad Parents: a78726d Author: Michael Ho <[email protected]> Authored: Fri Feb 3 20:55:18 2017 -0800 Committer: Impala Public Jenkins <[email protected]> Committed: Sat Feb 11 07:07:45 2017 +0000 ---------------------------------------------------------------------- be/src/benchmarks/hash-benchmark.cc | 18 +- be/src/codegen/llvm-codegen-test.cc | 60 +++-- be/src/codegen/llvm-codegen.cc | 98 +++++++-- be/src/codegen/llvm-codegen.h | 48 ++-- be/src/exec/partitioned-aggregation-node.cc | 8 +- be/src/exec/partitioned-aggregation-node.h | 2 +- be/src/exprs/agg-fn-evaluator.cc | 2 +- be/src/exprs/aggregate-functions-ir.cc | 12 +- be/src/exprs/conditional-functions-ir.cc | 27 ++- be/src/exprs/decimal-functions-ir.cc | 78 +++---- be/src/exprs/decimal-operators-ir.cc | 219 ++++++++++--------- be/src/exprs/decimal-operators.h | 2 +- be/src/exprs/expr-codegen-test.cc | 152 ++++++++----- be/src/exprs/expr-test.cc | 12 +- be/src/exprs/expr.cc | 86 -------- be/src/exprs/expr.h | 59 ----- be/src/exprs/math-functions-ir.cc | 4 +- be/src/exprs/scalar-fn-call.cc | 4 +- 
be/src/runtime/decimal-value.inline.h | 3 +- be/src/runtime/lib-cache.cc | 9 +- be/src/runtime/runtime-state.cc | 2 +- be/src/runtime/runtime-state.h | 1 + be/src/service/frontend.cc | 2 +- be/src/udf/udf-internal.h | 47 +++- be/src/udf/udf-test-harness.cc | 5 +- be/src/udf/udf-test-harness.h | 4 +- be/src/udf/udf.cc | 60 ++++- .../queries/QueryTest/decimal.test | 44 ++++ tests/query_test/test_decimal_casting.py | 70 +++--- 29 files changed, 646 insertions(+), 492 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/benchmarks/hash-benchmark.cc ---------------------------------------------------------------------- diff --git a/be/src/benchmarks/hash-benchmark.cc b/be/src/benchmarks/hash-benchmark.cc index 125310b..fb37bed 100644 --- a/be/src/benchmarks/hash-benchmark.cc +++ b/be/src/benchmarks/hash-benchmark.cc @@ -26,6 +26,7 @@ #include "experiments/data-provider.h" #include "runtime/mem-tracker.h" #include "runtime/string-value.h" +#include "runtime/test-env.h" #include "util/benchmark.h" #include "util/cpu-info.h" #include "util/hash-util.h" @@ -419,17 +420,24 @@ int main(int argc, char **argv) { const int NUM_ROWS = 1024; - ObjectPool obj_pool; + Status status; + RuntimeState* state; + TestEnv test_env; + status = test_env.CreateQueryState(0, 0, 0, nullptr, &state); + if (!status.ok()) { + cout << "Could not create RuntimeState"; + return -1; + } + MemTracker tracker; MemPool mem_pool(&tracker); - RuntimeProfile int_profile(&obj_pool, "IntGen"); - RuntimeProfile mixed_profile(&obj_pool, "MixedGen"); + RuntimeProfile int_profile(state->obj_pool(), "IntGen"); + RuntimeProfile mixed_profile(state->obj_pool(), "MixedGen"); DataProvider int_provider(&mem_pool, &int_profile); DataProvider mixed_provider(&mem_pool, &mixed_profile); - Status status; scoped_ptr<LlvmCodeGen> codegen; - status = LlvmCodeGen::CreateImpalaCodegen(&obj_pool, NULL, "test", &codegen); + 
status = LlvmCodeGen::CreateImpalaCodegen(state, NULL, "test", &codegen); if (!status.ok()) { cout << "Could not start codegen."; return -1; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/codegen/llvm-codegen-test.cc ---------------------------------------------------------------------- diff --git a/be/src/codegen/llvm-codegen-test.cc b/be/src/codegen/llvm-codegen-test.cc index 412426f..59c66f4 100644 --- a/be/src/codegen/llvm-codegen-test.cc +++ b/be/src/codegen/llvm-codegen-test.cc @@ -24,6 +24,8 @@ #include "common/init.h" #include "common/object-pool.h" #include "runtime/string-value.h" +#include "runtime/test-env.h" +#include "service/fe-support.h" #include "util/cpu-info.h" #include "util/hash-util.h" #include "util/path-builder.h" @@ -38,13 +40,27 @@ namespace impala { class LlvmCodeGenTest : public testing:: Test { protected: + scoped_ptr<TestEnv> test_env_; + RuntimeState* runtime_state_; + + virtual void SetUp() { + test_env_.reset(new TestEnv()); + EXPECT_OK(test_env_->CreateQueryState(0, 1, 8 * 1024 * 1024, nullptr, + &runtime_state_)); + } + + virtual void TearDown() { + runtime_state_ = NULL; + test_env_.reset(); + } + static void LifetimeTest() { ObjectPool pool; Status status; for (int i = 0; i < 10; ++i) { - LlvmCodeGen object1(&pool, NULL, "Test"); - LlvmCodeGen object2(&pool, NULL, "Test"); - LlvmCodeGen object3(&pool, NULL, "Test"); + LlvmCodeGen object1(NULL, &pool, NULL, "Test"); + LlvmCodeGen object2(NULL, &pool, NULL, "Test"); + LlvmCodeGen object3(NULL, &pool, NULL, "Test"); ASSERT_OK(object1.Init(unique_ptr<Module>(new Module("Test", object1.context())))); ASSERT_OK(object2.Init(unique_ptr<Module>(new Module("Test", object2.context())))); @@ -53,22 +69,12 @@ class LlvmCodeGenTest : public testing:: Test { } // Wrapper to call private test-only methods on LlvmCodeGen object - static Status CreateFromFile( - ObjectPool* pool, const string& filename, scoped_ptr<LlvmCodeGen>* codegen) { - 
RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(pool, NULL, filename, "test", codegen)); + Status CreateFromFile(const string& filename, scoped_ptr<LlvmCodeGen>* codegen) { + RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(runtime_state_, + runtime_state_->obj_pool(), NULL, filename, "test", codegen)); return (*codegen)->MaterializeModule(); } - static LlvmCodeGen* CreateCodegen(ObjectPool* pool) { - LlvmCodeGen* codegen = pool->Add(new LlvmCodeGen(pool, NULL, "Test")); - if (codegen != NULL) { - Status status = - codegen->Init(unique_ptr<Module>(new Module("Test", codegen->context()))); - if (!status.ok()) return NULL; - } - return codegen; - } - static void ClearHashFns(LlvmCodeGen* codegen) { codegen->ClearHashFns(); } @@ -104,11 +110,9 @@ TEST_F(LlvmCodeGenTest, MultithreadedLifetime) { // Test loading a non-existent file TEST_F(LlvmCodeGenTest, BadIRFile) { - ObjectPool pool; string module_file = "NonExistentFile.ir"; scoped_ptr<LlvmCodeGen> codegen; - EXPECT_FALSE( - LlvmCodeGenTest::CreateFromFile(&pool, module_file.c_str(), &codegen).ok()); + EXPECT_FALSE(CreateFromFile(module_file.c_str(), &codegen).ok()); } // IR for the generated linner loop @@ -159,7 +163,6 @@ Function* CodegenInnerLoop(LlvmCodeGen* codegen, int64_t* jitted_counter, int de // 5. Updated the jitted loop in place with another jitted inner loop function // 6. Run the loop and make sure the updated is called. TEST_F(LlvmCodeGenTest, ReplaceFnCall) { - ObjectPool pool; const string loop_call_name("_Z21DefaultImplementationv"); const string loop_name("_Z8TestLoopi"); typedef void (*TestLoopFn)(int); @@ -169,7 +172,7 @@ TEST_F(LlvmCodeGenTest, ReplaceFnCall) { // Part 1: Load the module and make sure everything is loaded correctly. 
scoped_ptr<LlvmCodeGen> codegen; - ASSERT_OK(LlvmCodeGenTest::CreateFromFile(&pool, module_file.c_str(), &codegen)); + ASSERT_OK(CreateFromFile(module_file.c_str(), &codegen)); EXPECT_TRUE(codegen.get() != NULL); Function* loop_call = codegen->GetFunction(loop_call_name, false); @@ -285,10 +288,8 @@ Function* CodegenStringTest(LlvmCodeGen* codegen) { // struct. Just create a simple StringValue struct and make sure the IR can read it // and modify it. TEST_F(LlvmCodeGenTest, StringValue) { - ObjectPool pool; - scoped_ptr<LlvmCodeGen> codegen; - ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(&pool, NULL, "test", &codegen)); + ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(runtime_state_, NULL, "test", &codegen)); EXPECT_TRUE(codegen.get() != NULL); string str("Test"); @@ -328,10 +329,8 @@ TEST_F(LlvmCodeGenTest, StringValue) { // Test calling memcpy intrinsic TEST_F(LlvmCodeGenTest, MemcpyTest) { - ObjectPool pool; - scoped_ptr<LlvmCodeGen> codegen; - ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(&pool, NULL, "test", &codegen)); + ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(runtime_state_, NULL, "test", &codegen)); ASSERT_TRUE(codegen.get() != NULL); LlvmCodeGen::FnPrototype prototype(codegen.get(), "MemcpyTest", codegen->void_type()); @@ -367,8 +366,6 @@ TEST_F(LlvmCodeGenTest, MemcpyTest) { // Test codegen for hash TEST_F(LlvmCodeGenTest, HashTest) { - ObjectPool pool; - // Values to compute hash on const char* data1 = "test string"; const char* data2 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; @@ -378,7 +375,7 @@ TEST_F(LlvmCodeGenTest, HashTest) { // Loop to test both the sse4 on/off paths for (int i = 0; i < 2; ++i) { scoped_ptr<LlvmCodeGen> codegen; - ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(&pool, NULL, "test", &codegen)); + ASSERT_OK(LlvmCodeGen::CreateImpalaCodegen(runtime_state_, NULL, "test", &codegen)); ASSERT_TRUE(codegen.get() != NULL); Value* llvm_data1 = @@ -452,7 +449,8 @@ TEST_F(LlvmCodeGenTest, HashTest) { int main(int argc, char **argv) { 
::testing::InitGoogleTest(&argc, argv); - impala::InitCommonRuntime(argc, argv, false, impala::TestInfo::BE_TEST); + impala::InitCommonRuntime(argc, argv, true, impala::TestInfo::BE_TEST); + impala::InitFeSupport(false); impala::LlvmCodeGen::InitializeLlvm(); return RUN_ALL_TESTS(); } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/codegen/llvm-codegen.cc ---------------------------------------------------------------------- diff --git a/be/src/codegen/llvm-codegen.cc b/be/src/codegen/llvm-codegen.cc index 6cf43f6..a548654 100644 --- a/be/src/codegen/llvm-codegen.cc +++ b/be/src/codegen/llvm-codegen.cc @@ -58,12 +58,14 @@ #include "codegen/instruction-counter.h" #include "codegen/mcjit-mem-mgr.h" #include "common/logging.h" +#include "exprs/anyval-util.h" #include "impala-ir/impala-ir-names.h" #include "runtime/descriptors.h" #include "runtime/hdfs-fs-cache.h" #include "runtime/lib-cache.h" #include "runtime/mem-pool.h" #include "runtime/mem-tracker.h" +#include "runtime/runtime-state.h" #include "runtime/string-value.h" #include "runtime/timestamp-value.h" #include "util/cpu-info.h" @@ -170,7 +172,8 @@ Status LlvmCodeGen::InitializeLlvm(bool load_backend) { ObjectPool init_pool; scoped_ptr<LlvmCodeGen> init_codegen; - RETURN_IF_ERROR(LlvmCodeGen::CreateFromMemory(&init_pool, NULL, "init", &init_codegen)); + RETURN_IF_ERROR(LlvmCodeGen::CreateFromMemory( + nullptr, &init_pool, nullptr, "init", &init_codegen)); // LLVM will construct "use" lists only when the entire module is materialized. 
RETURN_IF_ERROR(init_codegen->MaterializeModule()); @@ -209,9 +212,10 @@ Status LlvmCodeGen::InitializeLlvm(bool load_backend) { return Status::OK(); } -LlvmCodeGen::LlvmCodeGen( - ObjectPool* pool, MemTracker* parent_mem_tracker, const string& id) - : id_(id), +LlvmCodeGen::LlvmCodeGen(RuntimeState* state, ObjectPool* pool, + MemTracker* parent_mem_tracker, const string& id) + : state_(state), + id_(id), profile_(pool, "CodeGen"), mem_tracker_(new MemTracker(&profile_, -1, "CodeGen", parent_mem_tracker)), optimizations_enabled_(false), @@ -233,9 +237,10 @@ LlvmCodeGen::LlvmCodeGen( num_instructions_ = ADD_COUNTER(&profile_, "NumInstructions", TUnit::UNIT); } -Status LlvmCodeGen::CreateFromFile(ObjectPool* pool, MemTracker* parent_mem_tracker, - const string& file, const string& id, scoped_ptr<LlvmCodeGen>* codegen) { - codegen->reset(new LlvmCodeGen(pool, parent_mem_tracker, id)); +Status LlvmCodeGen::CreateFromFile(RuntimeState* state, ObjectPool* pool, + MemTracker* parent_mem_tracker, const string& file, const string& id, + scoped_ptr<LlvmCodeGen>* codegen) { + codegen->reset(new LlvmCodeGen(state, pool, parent_mem_tracker, id)); SCOPED_TIMER((*codegen)->profile_.total_time_counter()); unique_ptr<Module> loaded_module; @@ -244,9 +249,9 @@ Status LlvmCodeGen::CreateFromFile(ObjectPool* pool, MemTracker* parent_mem_trac return (*codegen)->Init(std::move(loaded_module)); } -Status LlvmCodeGen::CreateFromMemory(ObjectPool* pool, MemTracker* parent_mem_tracker, - const string& id, scoped_ptr<LlvmCodeGen>* codegen) { - codegen->reset(new LlvmCodeGen(pool, parent_mem_tracker, id)); +Status LlvmCodeGen::CreateFromMemory(RuntimeState* state, ObjectPool* pool, + MemTracker* parent_mem_tracker, const string& id, scoped_ptr<LlvmCodeGen>* codegen) { + codegen->reset(new LlvmCodeGen(state, pool, parent_mem_tracker, id)); SCOPED_TIMER((*codegen)->profile_.total_time_counter()); // Select the appropriate IR version. 
We cannot use LLVM IR with SSE4.2 instructions on @@ -355,9 +360,12 @@ void LlvmCodeGen::StripGlobalCtorsDtors(llvm::Module* module) { if (destructors != NULL) destructors->eraseFromParent(); } -Status LlvmCodeGen::CreateImpalaCodegen(ObjectPool* pool, MemTracker* parent_mem_tracker, - const string& id, scoped_ptr<LlvmCodeGen>* codegen_ret) { - RETURN_IF_ERROR(CreateFromMemory(pool, parent_mem_tracker, id, codegen_ret)); +Status LlvmCodeGen::CreateImpalaCodegen(RuntimeState* state, + MemTracker* parent_mem_tracker, const string& id, + scoped_ptr<LlvmCodeGen>* codegen_ret) { + DCHECK(state != nullptr); + RETURN_IF_ERROR(CreateFromMemory( + state, state->obj_pool(), parent_mem_tracker, id, codegen_ret)); LlvmCodeGen* codegen = codegen_ret->get(); // Parse module for cross compiled functions and types @@ -681,17 +689,19 @@ Function* LlvmCodeGen::GetFunction(IRFunction::Type ir_type, bool clone) { bool LlvmCodeGen::VerifyFunction(Function* fn) { if (is_corrupt_) return false; - // Check that there are no calls to Expr::GetConstant(). These should all have been - // inlined via Expr::InlineConstants(). + // Check that there are no calls to FunctionContextImpl::GetConstFnAttr(). These should all + // have been inlined via InlineConstFnAttrs(). for (inst_iterator iter = inst_begin(fn); iter != inst_end(fn); ++iter) { Instruction* instr = &*iter; if (!isa<CallInst>(instr)) continue; CallInst* call_instr = reinterpret_cast<CallInst*>(instr); Function* called_fn = call_instr->getCalledFunction(); - // look for call to Expr::GetConstant() - if (called_fn != NULL && - called_fn->getName().find(Expr::GET_CONSTANT_INT_SYMBOL_PREFIX) != string::npos) { - LOG(ERROR) << "Found call to Expr::GetConstant*(): " << Print(call_instr); + + // Look for call to FunctionContextImpl::GetConstFnAttr(). 
+ if (called_fn != nullptr && + called_fn->getName() == FunctionContextImpl::GET_CONST_FN_ATTR_SYMBOL) { + LOG(ERROR) << "Found call to FunctionContextImpl::GetConstFnAttr(): " + << Print(call_instr); is_corrupt_ = true; break; } @@ -916,6 +926,47 @@ int LlvmCodeGen::ReplaceCallSitesWithBoolConst(llvm::Function* caller, bool cons return ReplaceCallSitesWithValue(caller, replacement, target_name); } +int LlvmCodeGen::InlineConstFnAttrs(const FunctionContext::TypeDesc& ret_type, + const vector<FunctionContext::TypeDesc>& arg_types, Function* fn) { + int replaced = 0; + for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end; ) { + // Increment iter now so we don't mess it up modifying the instruction below + Instruction* instr = &*(iter++); + + // Look for call instructions + if (!isa<CallInst>(instr)) continue; + CallInst* call_instr = cast<CallInst>(instr); + Function* called_fn = call_instr->getCalledFunction(); + + // Look for call to FunctionContextImpl::GetConstFnAttr(). + if (called_fn == nullptr || + called_fn->getName() != FunctionContextImpl::GET_CONST_FN_ATTR_SYMBOL) { + continue; + } + + // 't' and 'i' arguments must be constant + ConstantInt* t_arg = dyn_cast<ConstantInt>(call_instr->getArgOperand(1)); + ConstantInt* i_arg = dyn_cast<ConstantInt>(call_instr->getArgOperand(2)); + // This optimization is only applied to built-ins which should have constant args. + DCHECK(t_arg != nullptr) + << "Non-constant 't' argument to FunctionContextImpl::GetConstFnAttr()"; + DCHECK(i_arg != nullptr) + << "Non-constant 'i' argument to FunctionContextImpl::GetConstFnAttr"; + + // Replace the called function with the appropriate constant + FunctionContextImpl::ConstFnAttr t_val = + static_cast<FunctionContextImpl::ConstFnAttr>(t_arg->getSExtValue()); + int i_val = static_cast<int>(i_arg->getSExtValue()); + DCHECK(state_ != nullptr); + // All supported constants are currently integers. 
+ call_instr->replaceAllUsesWith(ConstantInt::get(GetType(TYPE_INT), + FunctionContextImpl::GetConstFnAttr(state_, ret_type, arg_types, t_val, i_val))); + call_instr->eraseFromParent(); + ++replaced; + } + return replaced; +} + void LlvmCodeGen::FindCallSites(Function* caller, const string& target_name, vector<CallInst*>* results) { for (inst_iterator iter = inst_begin(caller); iter != inst_end(caller); ++iter) { @@ -1216,10 +1267,15 @@ void LlvmCodeGen::CodegenDebugTrace(LlvmBuilder* builder, const char* str, builder->CreateCall(printf, calling_args); } -void LlvmCodeGen::GetSymbols(unordered_set<string>* symbols) { - for (const Function& fn: module_->functions()) { +Status LlvmCodeGen::GetSymbols(const string& file, const string& module_id, + unordered_set<string>* symbols) { + ObjectPool pool; + scoped_ptr<LlvmCodeGen> codegen; + RETURN_IF_ERROR(CreateFromFile(nullptr, &pool, nullptr, file, module_id, &codegen)); + for (const Function& fn: codegen->module_->functions()) { if (fn.isMaterializable()) symbols->insert(fn.getName()); } + return Status::OK(); } // TODO: cache this function (e.g. all min(int, int) are identical). http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/codegen/llvm-codegen.h ---------------------------------------------------------------------- diff --git a/be/src/codegen/llvm-codegen.h b/be/src/codegen/llvm-codegen.h index d51faab..058403d 100644 --- a/be/src/codegen/llvm-codegen.h +++ b/be/src/codegen/llvm-codegen.h @@ -153,17 +153,9 @@ class LlvmCodeGen { /// 'codegen' will contain the created object on success. /// 'parent_mem_tracker' - if non-NULL, the CodeGen MemTracker is created under this. /// 'id' is used for outputting the IR module for debugging. 
- static Status CreateImpalaCodegen(ObjectPool*, MemTracker* parent_mem_tracker, + static Status CreateImpalaCodegen(RuntimeState* state, MemTracker* parent_mem_tracker, const std::string& id, boost::scoped_ptr<LlvmCodeGen>* codegen); - /// Creates a LlvmCodeGen instance initialized with the module bitcode from 'file'. - /// 'codegen' will contain the created object on success. The functions in the module - /// are materialized lazily. Getting a reference to a function via GetFunction() will - /// materialize the function and its callees recursively. - static Status CreateFromFile(ObjectPool*, MemTracker* parent_mem_tracker, - const std::string& file, const std::string& id, - boost::scoped_ptr<LlvmCodeGen>* codegen); - /// Removes all jit compiled dynamically linked functions from the process. ~LlvmCodeGen(); @@ -204,6 +196,7 @@ class LlvmCodeGen { void AddArgument(const NamedVariable& var) { args_.push_back(var); } + void AddArgument(const std::string& name, llvm::Type* type) { args_.push_back(NamedVariable(name, type)); } @@ -333,6 +326,17 @@ class LlvmCodeGen { int ReplaceCallSitesWithValue(llvm::Function* caller, llvm::Value* replacement, const std::string& target_name); + /// This function replaces calls to FunctionContextImpl::GetConstFnAttr() with constants + /// derived from 'return_type', 'arg_types' and the runtime state 'state_'. Please note + /// that this function only replaces call instructions inside 'fn' so to replace the + /// call to FunctionContextImpl::GetConstFnAttr() inside the callee functions, please + /// inline the callee functions (by annotating them with IR_ALWAYS_INLINE). + /// + /// TODO: implement a loop unroller (or use LLVM's) so we can use + /// FunctionContextImpl::GetConstFnAttr() in loops + int InlineConstFnAttrs(const FunctionContext::TypeDesc& return_type, + const std::vector<FunctionContext::TypeDesc>& arg_types, llvm::Function* fn); + /// Returns a copy of fn. The copy is added to the module. 
llvm::Function* CloneFunction(llvm::Function* fn); @@ -471,8 +475,9 @@ class LlvmCodeGen { llvm::Type* void_type() { return void_type_; } llvm::Type* i128_type() { return llvm::Type::getIntNTy(context(), 128); } - /// Fils in 'symbols' with all the symbols in the module. - void GetSymbols(boost::unordered_set<std::string>* symbols); + /// Load the module temporarily and populate 'symbols' with the symbols in the module. + static Status GetSymbols(const string& file, const string& module_id, + boost::unordered_set<std::string>* symbols); /// Generates function to return min/max(v1, v2) llvm::Function* CodegenMinMax(const ColumnType& type, bool min); @@ -524,19 +529,28 @@ class LlvmCodeGen { static void FindGlobalUsers(llvm::User* val, std::vector<llvm::GlobalObject*>* users); /// Top level codegen object. 'module_id' is used for debugging when outputting the IR. - LlvmCodeGen( - ObjectPool* pool, MemTracker* parent_mem_tracker, const std::string& module_id); + LlvmCodeGen(RuntimeState* state, ObjectPool* pool, MemTracker* parent_mem_tracker, + const std::string& module_id); /// Initializes the jitter and execution engine with the given module. Status Init(std::unique_ptr<llvm::Module> module); - /// Creates a LlvmCodeGen instance initialized with the module bitcode in memory. + /// Creates a LlvmCodeGen instance initialized with the module bitcode from 'file'. /// 'codegen' will contain the created object on success. The functions in the module /// are materialized lazily. Getting a reference to a function via GetFunction() will /// materialize the function and its callees recursively. - static Status CreateFromMemory(ObjectPool* pool, MemTracker* parent_mem_tracker, + static Status CreateFromFile(RuntimeState* state, ObjectPool* pool, + MemTracker* parent_mem_tracker, const std::string& file, const std::string& id, boost::scoped_ptr<LlvmCodeGen>* codegen); + /// Creates a LlvmCodeGen instance initialized with the module bitcode in memory. 
+ /// 'codegen' will contain the created object on success. The functions in the module + /// are materialized lazily. Getting a reference to a function via GetFunction() will + /// materialize the function and its callees recursively. + static Status CreateFromMemory(RuntimeState* state, ObjectPool* pool, + MemTracker* parent_mem_tracker, const std::string& id, + boost::scoped_ptr<LlvmCodeGen>* codegen); + /// Loads an LLVM module from 'file' which is the local path to the LLVM bitcode file. /// The functions in the module are materialized lazily. Getting a reference to the /// function via GetFunction() will materialize the function and its callees @@ -639,6 +653,10 @@ class LlvmCodeGen { /// when a module is loaded to ensure that LLVM can resolve references to them. static boost::unordered_set<std::string> fns_to_always_materialize_; + /// Pointer to the RuntimeState which owns this codegen object. Needed in + /// InlineConstFnAttr() to access the query options. + const RuntimeState* state_; + /// ID used for debugging (can be e.g. the fragment instance ID) std::string id_; http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exec/partitioned-aggregation-node.cc ---------------------------------------------------------------------- diff --git a/be/src/exec/partitioned-aggregation-node.cc b/be/src/exec/partitioned-aggregation-node.cc index 7c75d01..83232d2 100644 --- a/be/src/exec/partitioned-aggregation-node.cc +++ b/be/src/exec/partitioned-aggregation-node.cc @@ -1681,8 +1681,8 @@ Status PartitionedAggregationNode::CodegenUpdateSlot(LlvmCodeGen* codegen, // Call the UDA to update/merge 'src' into 'dst', with the result stored in // 'updated_dst_val'. 
CodegenAnyVal updated_dst_val; - RETURN_IF_ERROR(CodegenCallUda( - codegen, &builder, evaluator, agg_fn_ctx_arg, input_vals, dst, &updated_dst_val)); + RETURN_IF_ERROR(CodegenCallUda(codegen, &builder, evaluator, agg_fn_ctx_arg, + input_vals, dst, &updated_dst_val)); result = updated_dst_val.ToNativeValue(); if (slot_desc->is_nullable() && !special_null_handling) { @@ -1717,7 +1717,7 @@ Status PartitionedAggregationNode::CodegenUpdateSlot(LlvmCodeGen* codegen, } Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, - LlvmBuilder* builder, AggFnEvaluator* evaluator, Value* agg_fn_ctx, + LlvmBuilder* builder, AggFnEvaluator* evaluator, Value* agg_fn_ctx_arg, const vector<CodegenAnyVal>& input_vals, const CodegenAnyVal& dst, CodegenAnyVal* updated_dst_val) { DCHECK_EQ(evaluator->input_expr_ctxs().size(), input_vals.size()); @@ -1727,7 +1727,7 @@ Status PartitionedAggregationNode::CodegenCallUda(LlvmCodeGen* codegen, // Set up arguments for call to UDA, which are the FunctionContext*, followed by // pointers to all input values, followed by a pointer to the destination value. vector<Value*> uda_fn_args; - uda_fn_args.push_back(agg_fn_ctx); + uda_fn_args.push_back(agg_fn_ctx_arg); // Create pointers to input args to pass to uda_fn. We must use the unlowered type, // e.g. IntVal, because the UDA interface expects the values to be passed as const http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exec/partitioned-aggregation-node.h ---------------------------------------------------------------------- diff --git a/be/src/exec/partitioned-aggregation-node.h b/be/src/exec/partitioned-aggregation-node.h index 54840c7..f26a252 100644 --- a/be/src/exec/partitioned-aggregation-node.h +++ b/be/src/exec/partitioned-aggregation-node.h @@ -651,7 +651,7 @@ class PartitionedAggregationNode : public ExecNode { /// operation is applied. The instruction sequence for the UDA call is inserted at /// the insert position of 'builder'. 
Status CodegenCallUda(LlvmCodeGen* codegen, LlvmBuilder* builder, - AggFnEvaluator* evaluator, llvm::Value* agg_fn_ctx, + AggFnEvaluator* evaluator, llvm::Value* agg_fn_ctx_arg, const std::vector<CodegenAnyVal>& input_vals, const CodegenAnyVal& dst_val, CodegenAnyVal* updated_dst_val); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/agg-fn-evaluator.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/agg-fn-evaluator.cc b/be/src/exprs/agg-fn-evaluator.cc index b623130..81c0277 100644 --- a/be/src/exprs/agg-fn-evaluator.cc +++ b/be/src/exprs/agg-fn-evaluator.cc @@ -528,7 +528,7 @@ Status AggFnEvaluator::GetUpdateOrMergeFunction(LlvmCodeGen* codegen, Function** if (!(*uda_fn)->isDeclaration()) { // TODO: IMPALA-4785: we should also replace references to GetIntermediateType() // with constants. - Expr::InlineConstants(GetOutputTypeDesc(), arg_type_descs_, codegen, *uda_fn); + codegen->InlineConstFnAttrs(GetOutputTypeDesc(), arg_type_descs_, *uda_fn); *uda_fn = codegen->FinalizeFunction(*uda_fn); if (*uda_fn == NULL) { return Status(TErrorCode::UDF_VERIFY_FAILED, symbol, fn_.hdfs_location); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/aggregate-functions-ir.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/aggregate-functions-ir.cc b/be/src/exprs/aggregate-functions-ir.cc index 8505d43..531598c 100644 --- a/be/src/exprs/aggregate-functions-ir.cc +++ b/be/src/exprs/aggregate-functions-ir.cc @@ -401,7 +401,7 @@ IR_ALWAYS_INLINE void AggregateFunctions::DecimalAvgAddOrRemove(FunctionContext* // Since the src and dst are guaranteed to be the same scale, we can just // do a simple add. int m = remove ? 
-1 : 1; - switch (Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_SIZE, 0)) { + switch (ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0)) { case 4: avg->sum.val16 += m * src.val4; break; @@ -511,7 +511,7 @@ IR_ALWAYS_INLINE void AggregateFunctions::SumDecimalAddOrSubtract(FunctionContex const DecimalVal& src, DecimalVal* dst, bool subtract) { if (src.is_null) return; if (dst->is_null) InitZero<DecimalVal>(ctx, dst); - int precision = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_PRECISION, 0); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); // Since the src and dst are guaranteed to be the same scale, we can just // do a simple add. int m = subtract ? -1 : 1; @@ -573,7 +573,7 @@ template<> void AggregateFunctions::Min(FunctionContext* ctx, const DecimalVal& src, DecimalVal* dst) { if (src.is_null) return; - int precision = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_PRECISION, 0); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); if (precision <= 9) { if (dst->is_null || src.val4 < dst->val4) *dst = src; } else if (precision <= 19) { @@ -587,7 +587,7 @@ template<> void AggregateFunctions::Max(FunctionContext* ctx, const DecimalVal& src, DecimalVal* dst) { if (src.is_null) return; - int precision = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_PRECISION, 0); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); if (precision <= 9) { if (dst->is_null || src.val4 > dst->val4) *dst = src; } else if (precision <= 19) { @@ -1200,8 +1200,8 @@ void AggregateFunctions::HllUpdate( if (src.is_null) return; DCHECK(!dst->is_null); DCHECK_EQ(dst->len, HLL_LEN); - uint64_t hash_value = AnyValUtil::HashDecimal64( - src, Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_SIZE, 0), HashUtil::FNV64_SEED); + int byte_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); + uint64_t hash_value = AnyValUtil::HashDecimal64(src, byte_size, 
HashUtil::FNV64_SEED); if (hash_value != 0) { // Use the lower bits to index into the number of streams and then // find the first 1 bit after the index bits. http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/conditional-functions-ir.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/conditional-functions-ir.cc b/be/src/exprs/conditional-functions-ir.cc index ed98a87..662f167 100644 --- a/be/src/exprs/conditional-functions-ir.cc +++ b/be/src/exprs/conditional-functions-ir.cc @@ -23,11 +23,11 @@ using namespace impala; using namespace impala_udf; #define IS_NULL_COMPUTE_FUNCTION(type) \ - type IsNullExpr::Get##type(ExprContext* context, const TupleRow* row) { \ + type IsNullExpr::Get##type(ExprContext* ctx, const TupleRow* row) { \ DCHECK_EQ(children_.size(), 2); \ - type val = children_[0]->Get##type(context, row); \ + type val = children_[0]->Get##type(ctx, row); \ if (!val.is_null) return val; /* short-circuit */ \ - return children_[1]->Get##type(context, row); \ + return children_[1]->Get##type(ctx, row); \ } IS_NULL_COMPUTE_FUNCTION(BooleanVal); @@ -68,15 +68,14 @@ NULL_IF_COMPUTE_FUNCTION(TimestampVal); NULL_IF_COMPUTE_FUNCTION(DecimalVal); #define NULL_IF_ZERO_COMPUTE_FUNCTION(type) \ - type ConditionalFunctions::NullIfZero(FunctionContext* context, const type& val) { \ + type ConditionalFunctions::NullIfZero(FunctionContext* ctx, const type& val) { \ if (val.is_null || val.val == 0) return type::null(); \ return val; \ } -DecimalVal ConditionalFunctions::NullIfZero( - FunctionContext* context, const DecimalVal& val) { +DecimalVal ConditionalFunctions::NullIfZero(FunctionContext* ctx, const DecimalVal& val) { if (val.is_null) return DecimalVal::null(); - int type_byte_size = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SIZE); + int type_byte_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE); switch (type_byte_size) { case 4: if (val.val4 == 0) return 
DecimalVal::null(); @@ -101,7 +100,7 @@ NULL_IF_ZERO_COMPUTE_FUNCTION(FloatVal); NULL_IF_ZERO_COMPUTE_FUNCTION(DoubleVal); #define ZERO_IF_NULL_COMPUTE_FUNCTION(type) \ - type ConditionalFunctions::ZeroIfNull(FunctionContext* context, const type& val) { \ + type ConditionalFunctions::ZeroIfNull(FunctionContext* ctx, const type& val) { \ if (val.is_null) return type(0); \ return val; \ } @@ -115,13 +114,13 @@ ZERO_IF_NULL_COMPUTE_FUNCTION(DoubleVal); ZERO_IF_NULL_COMPUTE_FUNCTION(DecimalVal); #define IF_COMPUTE_FUNCTION(type) \ - type IfExpr::Get##type(ExprContext* context, const TupleRow* row) { \ + type IfExpr::Get##type(ExprContext* ctx, const TupleRow* row) { \ DCHECK_EQ(children_.size(), 3); \ - BooleanVal cond = children_[0]->GetBooleanVal(context, row); \ + BooleanVal cond = children_[0]->GetBooleanVal(ctx, row); \ if (cond.is_null || !cond.val) { \ - return children_[2]->Get##type(context, row); \ + return children_[2]->Get##type(ctx, row); \ } \ - return children_[1]->Get##type(context, row); \ + return children_[1]->Get##type(ctx, row); \ } IF_COMPUTE_FUNCTION(BooleanVal); @@ -136,10 +135,10 @@ IF_COMPUTE_FUNCTION(TimestampVal); IF_COMPUTE_FUNCTION(DecimalVal); #define COALESCE_COMPUTE_FUNCTION(type) \ - type CoalesceExpr::Get##type(ExprContext* context, const TupleRow* row) { \ + type CoalesceExpr::Get##type(ExprContext* ctx, const TupleRow* row) { \ DCHECK_GE(children_.size(), 1); \ for (int i = 0; i < children_.size(); ++i) { \ - type val = children_[i]->Get##type(context, row); \ + type val = children_[i]->Get##type(ctx, row); \ if (!val.is_null) return val; \ } \ return type::null(); \ http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/decimal-functions-ir.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/decimal-functions-ir.cc b/be/src/exprs/decimal-functions-ir.cc index cdbd799..3262e27 100644 --- a/be/src/exprs/decimal-functions-ir.cc +++ 
b/be/src/exprs/decimal-functions-ir.cc @@ -28,17 +28,17 @@ namespace impala { -IntVal DecimalFunctions::Precision(FunctionContext* context, const DecimalVal& val) { - return IntVal(Expr::GetConstantInt(*context, Expr::ARG_TYPE_PRECISION, 0)); +IntVal DecimalFunctions::Precision(FunctionContext* ctx, const DecimalVal& val) { + return IntVal(ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0)); } -IntVal DecimalFunctions::Scale(FunctionContext* context, const DecimalVal& val) { - return IntVal(Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0)); +IntVal DecimalFunctions::Scale(FunctionContext* ctx, const DecimalVal& val) { + return IntVal(ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0)); } -DecimalVal DecimalFunctions::Abs(FunctionContext* context, const DecimalVal& val) { +DecimalVal DecimalFunctions::Abs(FunctionContext* ctx, const DecimalVal& val) { if (val.is_null) return DecimalVal::null(); - int type_byte_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0); + int type_byte_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); switch (type_byte_size) { case 4: return DecimalVal(abs(val.val4)); @@ -52,79 +52,83 @@ DecimalVal DecimalFunctions::Abs(FunctionContext* context, const DecimalVal& val } } -DecimalVal DecimalFunctions::Ceil(FunctionContext* context, const DecimalVal& val) { - return DecimalOperators::RoundDecimal(context, val, DecimalOperators::CEIL); +DecimalVal DecimalFunctions::Ceil(FunctionContext* ctx, const DecimalVal& val) { + return DecimalOperators::RoundDecimal(ctx, val, DecimalOperators::CEIL); } -DecimalVal DecimalFunctions::Floor(FunctionContext* context, const DecimalVal& val) { - return DecimalOperators::RoundDecimal(context, val, DecimalOperators::FLOOR); +DecimalVal DecimalFunctions::Floor(FunctionContext* ctx, const DecimalVal& val) { + return DecimalOperators::RoundDecimal(ctx, val, DecimalOperators::FLOOR); } -DecimalVal DecimalFunctions::Round(FunctionContext* 
context, const DecimalVal& val) { - return DecimalOperators::RoundDecimal(context, val, DecimalOperators::ROUND); +DecimalVal DecimalFunctions::Round(FunctionContext* ctx, const DecimalVal& val) { + return DecimalOperators::RoundDecimal(ctx, val, DecimalOperators::ROUND); } /// Always inline in IR module so that constants can be replaced. IR_ALWAYS_INLINE DecimalVal DecimalFunctions::RoundTo( - FunctionContext* context, const DecimalVal& val, int scale, + FunctionContext* ctx, const DecimalVal& val, int scale, DecimalOperators::DecimalRoundOp op) { - int val_precision = Expr::GetConstantInt(*context, Expr::ARG_TYPE_PRECISION, 0); - int val_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); - int return_precision = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); - int return_scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); + int val_precision = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); + int val_scale = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); + int return_precision = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); + int return_scale = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); if (scale < 0) { - return DecimalOperators::RoundDecimalNegativeScale(context, + return DecimalOperators::RoundDecimalNegativeScale(ctx, val, val_precision, val_scale, return_precision, return_scale, op, -scale); } else { - return DecimalOperators::RoundDecimal(context, + return DecimalOperators::RoundDecimal(ctx, val, val_precision, val_scale, return_precision, return_scale, op); } } DecimalVal DecimalFunctions::RoundTo( - FunctionContext* context, const DecimalVal& val, const TinyIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const TinyIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::ROUND); + return RoundTo(ctx, val, scale.val, DecimalOperators::ROUND); } 
DecimalVal DecimalFunctions::RoundTo( - FunctionContext* context, const DecimalVal& val, const SmallIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const SmallIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::ROUND); + return RoundTo(ctx, val, scale.val, DecimalOperators::ROUND); } DecimalVal DecimalFunctions::RoundTo( - FunctionContext* context, const DecimalVal& val, const IntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const IntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::ROUND); + return RoundTo(ctx, val, scale.val, DecimalOperators::ROUND); } DecimalVal DecimalFunctions::RoundTo( - FunctionContext* context, const DecimalVal& val, const BigIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const BigIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::ROUND); + return RoundTo(ctx, val, scale.val, DecimalOperators::ROUND); } -DecimalVal DecimalFunctions::Truncate(FunctionContext* context, const DecimalVal& val) { - return DecimalOperators::RoundDecimal(context, val, DecimalOperators::TRUNCATE); +DecimalVal DecimalFunctions::Truncate(FunctionContext* ctx, const DecimalVal& val) { + return DecimalOperators::RoundDecimal(ctx, val, DecimalOperators::TRUNCATE); } DecimalVal DecimalFunctions::TruncateTo( - FunctionContext* context, const DecimalVal& val, const TinyIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const TinyIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::TRUNCATE); + return RoundTo(ctx, val, scale.val, DecimalOperators::TRUNCATE); } DecimalVal DecimalFunctions::TruncateTo( - FunctionContext* context, const DecimalVal& val, const SmallIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const SmallIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, 
DecimalOperators::TRUNCATE); + return RoundTo(ctx, val, scale.val, DecimalOperators::TRUNCATE); } DecimalVal DecimalFunctions::TruncateTo( - FunctionContext* context, const DecimalVal& val, const IntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const IntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::TRUNCATE); + return RoundTo(ctx, val, scale.val, DecimalOperators::TRUNCATE); } DecimalVal DecimalFunctions::TruncateTo( - FunctionContext* context, const DecimalVal& val, const BigIntVal& scale) { + FunctionContext* ctx, const DecimalVal& val, const BigIntVal& scale) { DCHECK(!scale.is_null); - return RoundTo(context, val, scale.val, DecimalOperators::TRUNCATE); + return RoundTo(ctx, val, scale.val, DecimalOperators::TRUNCATE); } } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/decimal-operators-ir.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/decimal-operators-ir.cc b/be/src/exprs/decimal-operators-ir.cc index 516acbc..4dfc8dc 100644 --- a/be/src/exprs/decimal-operators-ir.cc +++ b/be/src/exprs/decimal-operators-ir.cc @@ -34,32 +34,32 @@ namespace impala { -#define RETURN_IF_OVERFLOW(context, overflow) \ +#define RETURN_IF_OVERFLOW(ctx, overflow) \ do {\ if (UNLIKELY(overflow)) {\ - context->AddWarning("Expression overflowed, returning NULL");\ + ctx->AddWarning("Expression overflowed, returning NULL");\ return DecimalVal::null();\ }\ } while (false) // Inline in IR module so branches can be optimised out. 
IR_ALWAYS_INLINE DecimalVal DecimalOperators::IntToDecimalVal( - FunctionContext* context, int precision, int scale, int64_t val) { + FunctionContext* ctx, int precision, int scale, int64_t val) { bool overflow = false; switch (ColumnType::GetDecimalByteSize(precision)) { case 4: { Decimal4Value dv = Decimal4Value::FromInt(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } case 8: { Decimal8Value dv = Decimal8Value::FromInt(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } case 16: { Decimal16Value dv = Decimal16Value::FromInt(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } default: @@ -70,25 +70,25 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::IntToDecimalVal( // Inline in IR module so branches can be optimised out. 
IR_ALWAYS_INLINE DecimalVal DecimalOperators::FloatToDecimalVal( - FunctionContext* context, int precision, int scale, double val) { + FunctionContext* ctx, int precision, int scale, double val) { bool overflow = false; switch (ColumnType::GetDecimalByteSize(precision)) { case 4: { Decimal4Value dv = Decimal4Value::FromDouble(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } case 8: { Decimal8Value dv = Decimal8Value::FromDouble(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } case 16: { Decimal16Value dv = Decimal16Value::FromDouble(precision, scale, val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(dv.value()); } default: @@ -105,28 +105,28 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::FloatToDecimalVal( // When going from a smaller type to a larger type, we convert and then scale. // Inline these functions in IR module so branches can be optimised out. 
-IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* context, +IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* ctx, const Decimal4Value& val, int val_scale, int output_precision, int output_scale) { bool overflow = false; switch (ColumnType::GetDecimalByteSize(output_precision)) { case 4: { Decimal4Value scaled_val = val.ScaleTo( val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } case 8: { Decimal8Value val8 = ToDecimal8(val, &overflow); Decimal8Value scaled_val = val8.ScaleTo( val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } case 16: { Decimal16Value val16 = ToDecimal16(val, &overflow); Decimal16Value scaled_val = val16.ScaleTo( val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } default: @@ -135,7 +135,7 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* } } -IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* context, +IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* ctx, const Decimal8Value& val, int val_scale, int output_precision, int output_scale) { bool overflow = false; switch (ColumnType::GetDecimalByteSize(output_precision)) { @@ -143,20 +143,20 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* Decimal8Value scaled_val = val.ScaleTo( val_scale, output_scale, output_precision, &overflow); Decimal4Value val4 = ToDecimal4(scaled_val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(val4.value()); } case 8: { Decimal8Value scaled_val = val.ScaleTo( 
val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } case 16: { Decimal16Value val16 = ToDecimal16(val, &overflow); Decimal16Value scaled_val = val16.ScaleTo( val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } default: @@ -165,7 +165,7 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* } } -IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* context, +IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* ctx, const Decimal16Value& val, int val_scale, int output_precision, int output_scale) { bool overflow = false; switch (ColumnType::GetDecimalByteSize(output_precision)) { @@ -173,20 +173,20 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::ScaleDecimalValue(FunctionContext* Decimal16Value scaled_val = val.ScaleTo( val_scale, output_scale, output_precision, &overflow); Decimal4Value val4 = ToDecimal4(scaled_val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(val4.value()); } case 8: { Decimal16Value scaled_val = val.ScaleTo( val_scale, output_scale, output_precision, &overflow); Decimal8Value val8 = ToDecimal8(scaled_val, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(val8.value()); } case 16: { Decimal16Value scaled_val = val.ScaleTo( val_scale, output_scale, output_precision, &overflow); - RETURN_IF_OVERFLOW(context, overflow); + RETURN_IF_OVERFLOW(ctx, overflow); return DecimalVal(scaled_val.value()); } default: @@ -268,29 +268,31 @@ static inline Decimal16Value GetDecimal16Value( } #define CAST_INT_TO_DECIMAL(from_type) \ - DecimalVal DecimalOperators::CastToDecimalVal( \ - FunctionContext* context, const from_type& 
val) { \ + IR_ALWAYS_INLINE DecimalVal DecimalOperators::CastToDecimalVal( \ + FunctionContext* ctx, const from_type& val) { \ if (val.is_null) return DecimalVal::null(); \ - int precision = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); \ - int scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); \ - return IntToDecimalVal(context, precision, scale, val.val); \ + int precision = \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); \ + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); \ + return IntToDecimalVal(ctx, precision, scale, val.val); \ } #define CAST_FLOAT_TO_DECIMAL(from_type) \ - DecimalVal DecimalOperators::CastToDecimalVal( \ - FunctionContext* context, const from_type& val) { \ + IR_ALWAYS_INLINE DecimalVal DecimalOperators::CastToDecimalVal( \ + FunctionContext* ctx, const from_type& val) { \ if (val.is_null) return DecimalVal::null(); \ - int precision = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); \ - int scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); \ - return FloatToDecimalVal(context, precision, scale, val.val); \ + int precision = \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); \ + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); \ + return FloatToDecimalVal(ctx, precision, scale, val.val); \ } #define CAST_DECIMAL_TO_INT(to_type) \ - to_type DecimalOperators::CastTo##to_type( \ - FunctionContext* context, const DecimalVal& val) { \ + IR_ALWAYS_INLINE to_type DecimalOperators::CastTo##to_type( \ + FunctionContext* ctx, const DecimalVal& val) { \ if (val.is_null) return to_type::null(); \ - int scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); \ - switch (Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0)) { \ + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); \ + switch 
(ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0)) { \ case 4: { \ Decimal4Value dv(val.val4); \ return to_type(dv.whole_part(scale)); \ @@ -310,11 +312,11 @@ static inline Decimal16Value GetDecimal16Value( } #define CAST_DECIMAL_TO_FLOAT(to_type) \ - to_type DecimalOperators::CastTo##to_type( \ - FunctionContext* context, const DecimalVal& val) { \ + IR_ALWAYS_INLINE to_type DecimalOperators::CastTo##to_type( \ + FunctionContext* ctx, const DecimalVal& val) { \ if (val.is_null) return to_type::null(); \ - int scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); \ - switch (Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0)) { \ + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); \ + switch (ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0)) { \ case 4: { \ Decimal4Value dv(val.val4); \ return to_type(dv.ToDouble(scale)); \ @@ -349,7 +351,7 @@ CAST_DECIMAL_TO_FLOAT(DoubleVal) // Inline in IR module so branches can be optimised out. 
IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimalNegativeScale( - FunctionContext* context, const DecimalVal& val, int val_precision, int val_scale, + FunctionContext* ctx, const DecimalVal& val, int val_precision, int val_scale, int output_precision, int output_scale, const DecimalRoundOp& op, int64_t rounding_scale) { DCHECK_GT(rounding_scale, 0); @@ -360,19 +362,19 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimalNegativeScale( switch (ColumnType::GetDecimalByteSize(val_precision)) { case 4: { Decimal4Value val4(val.val4); - result = ScaleDecimalValue(context, val4, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val4, val_scale, output_precision, output_scale); break; } case 8: { Decimal8Value val8(val.val8); - result = ScaleDecimalValue(context, val8, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val8, val_scale, output_precision, output_scale); break; } case 16: { Decimal16Value val16(val.val16); - result = ScaleDecimalValue(context, val16, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val16, val_scale, output_precision, output_scale); break; } @@ -410,7 +412,7 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimalNegativeScale( // Need to check for overflow. This can't happen in the other cases since the // FE should have picked a high enough precision. if (DecimalUtil::MAX_UNSCALED_DECIMAL16 - abs(delta) < abs(val16.value())) { - context->AddWarning("Expression overflowed, returning NULL"); + ctx->AddWarning("Expression overflowed, returning NULL"); return DecimalVal::null(); } result.val16 += delta; @@ -424,7 +426,7 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimalNegativeScale( } // Inline in IR module so branches can be optimised out. 
-IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimal(FunctionContext* context, +IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimal(FunctionContext* ctx, const DecimalVal& val, int val_precision, int val_scale, int output_precision, int output_scale, const DecimalRoundOp& op) { if (val.is_null) return DecimalVal::null(); @@ -434,23 +436,23 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimal(FunctionContext* cont switch (ColumnType::GetDecimalByteSize(val_precision)) { case 4: { Decimal4Value val4(val.val4); - result = ScaleDecimalValue(context, val4, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val4, val_scale, output_precision, output_scale); delta = RoundDelta(val4, val_scale, output_scale, op); break; } case 8: { Decimal8Value val8(val.val8); - result = ScaleDecimalValue(context, val8, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val8, val_scale, output_precision, output_scale); delta = RoundDelta(val8, val_scale, output_scale, op); break; } case 16: { Decimal16Value val16(val.val16); - result = ScaleDecimalValue(context, val16, val_scale, output_precision, + result = ScaleDecimalValue(ctx, val16, val_scale, output_precision, output_scale); delta = RoundDelta(val16, val_scale, output_scale, op); break; } @@ -466,41 +468,49 @@ IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimal(FunctionContext* cont // done the cast. if (delta == 0) return result; - - // The value in 'result' is before the rounding has occurred. - // This can't overflow. Rounding to a non-negative scale means at least one digit is - // dropped if rounding occurred and the round can add at most one digit before the - // decimal. + // The value in 'result' is before any rounding has occurred. If there is any rounding, + // the output's scale must be less than the input's scale.
+ DCHECK_GT(val_scale, output_scale); result.val16 += delta; + + // Rounding to a non-negative scale means at least one digit is dropped if rounding + // occurred and the round can add at most one digit before the decimal. This cannot + // overflow if output_precision >= val_precision. Otherwise, result can overflow. + bool overflow = output_precision < val_precision && + abs(result.val16) >= DecimalUtil::GetScaleMultiplier<int128_t>(output_precision); + RETURN_IF_OVERFLOW(ctx, overflow); return result; } // Inline in IR module so branches can be optimised out. IR_ALWAYS_INLINE DecimalVal DecimalOperators::RoundDecimal( - FunctionContext* context, const DecimalVal& val, const DecimalRoundOp& op) { - int val_precision = Expr::GetConstantInt(*context, Expr::ARG_TYPE_PRECISION, 0); - int val_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); - int return_precision = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); - int return_scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); - return RoundDecimal(context, val, val_precision, val_scale, return_precision, + FunctionContext* ctx, const DecimalVal& val, const DecimalRoundOp& op) { + int val_precision = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); + int val_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); + int return_precision = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); + int return_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); + return RoundDecimal(ctx, val, val_precision, val_scale, return_precision, return_scale, op); } -// Cast is just RoundDecimal(TRUNCATE). -// TODO: how we handle cast to a smaller scale is an implementation detail in the spec. -// We could also choose to cast by doing ROUND. 
-DecimalVal DecimalOperators::CastToDecimalVal( - FunctionContext* context, const DecimalVal& val) { - return RoundDecimal(context, val, TRUNCATE); +// If query option decimal_v2 is true, cast is RoundDecimal(ROUND). +// Otherwise, it's RoundDecimal(TRUNCATE). +IR_ALWAYS_INLINE DecimalVal DecimalOperators::CastToDecimalVal( + FunctionContext* ctx, const DecimalVal& val) { + int is_decimal_v2 = ctx->impl()->GetConstFnAttr(FunctionContextImpl::DECIMAL_V2); + DCHECK(is_decimal_v2 == 0 || is_decimal_v2 == 1); + return RoundDecimal(ctx, val, is_decimal_v2 != 0 ? ROUND : TRUNCATE); } -DecimalVal DecimalOperators::CastToDecimalVal( - FunctionContext* context, const StringVal& val) { +IR_ALWAYS_INLINE DecimalVal DecimalOperators::CastToDecimalVal( + FunctionContext* ctx, const StringVal& val) { if (val.is_null) return DecimalVal::null(); StringParser::ParseResult result; DecimalVal dv; - int precision = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); - int scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); switch (ColumnType::GetDecimalByteSize(precision)) { case 4: { Decimal4Value dv4 = StringParser::StringToDecimal<int32_t>( @@ -534,10 +543,10 @@ DecimalVal DecimalOperators::CastToDecimalVal( } StringVal DecimalOperators::CastToStringVal( - FunctionContext* context, const DecimalVal& val) { + FunctionContext* ctx, const DecimalVal& val) { if (val.is_null) return StringVal::null(); - int precision = Expr::GetConstantInt(*context, Expr::ARG_TYPE_PRECISION, 0); - int scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); string s; switch (ColumnType::GetDecimalByteSize(precision)) { case 4: @@ 
-553,7 +562,7 @@ StringVal DecimalOperators::CastToStringVal( DCHECK(false); return StringVal::null(); } - StringVal result(context, s.size()); + StringVal result(ctx, s.size()); memcpy(result.ptr, s.c_str(), s.size()); return result; } @@ -574,10 +583,10 @@ IR_ALWAYS_INLINE T DecimalOperators::ConvertToNanoseconds(T val, int scale) { } TimestampVal DecimalOperators::CastToTimestampVal( - FunctionContext* context, const DecimalVal& val) { + FunctionContext* ctx, const DecimalVal& val) { if (val.is_null) return TimestampVal::null(); - int precision = Expr::GetConstantInt(*context, Expr::ARG_TYPE_PRECISION, 0); - int scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); + int precision = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_PRECISION, 0); + int scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); TimestampVal result; switch (ColumnType::GetDecimalByteSize(precision)) { case 4: { @@ -620,9 +629,9 @@ TimestampVal DecimalOperators::CastToTimestampVal( } BooleanVal DecimalOperators::CastToBooleanVal( - FunctionContext* context, const DecimalVal& val) { + FunctionContext* ctx, const DecimalVal& val) { if (val.is_null) return BooleanVal::null(); - switch (Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0)) { + switch (ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0)) { case 4: return BooleanVal(val.val4 != 0); case 8: @@ -637,17 +646,18 @@ BooleanVal DecimalOperators::CastToBooleanVal( #define DECIMAL_ARITHMETIC_OP(FN_NAME, OP_FN) \ DecimalVal DecimalOperators::FN_NAME( \ - FunctionContext* context, const DecimalVal& x, const DecimalVal& y) { \ + FunctionContext* ctx, const DecimalVal& x, const DecimalVal& y) { \ if (x.is_null || y.is_null) return DecimalVal::null(); \ bool overflow = false; \ - int x_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0); \ - int x_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); \ - int y_size = Expr::GetConstantInt(*context, 
Expr::ARG_TYPE_SIZE, 1); \ - int y_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 1); \ + int x_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); \ + int x_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); \ + int y_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1); \ + int y_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 1); \ int return_precision = \ - Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); \ - int return_scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); \ - switch (Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SIZE)) { \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); \ + int return_scale = \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); \ + switch (ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE)) { \ case 4: { \ Decimal4Value x_val = GetDecimal4Value(x, x_size, &overflow); \ Decimal4Value y_val = GetDecimal4Value(y, y_size, &overflow); \ @@ -669,7 +679,7 @@ BooleanVal DecimalOperators::CastToBooleanVal( Decimal16Value y_val = GetDecimal16Value(y, y_size, &overflow); \ Decimal16Value result = x_val.OP_FN<int128_t>(x_scale, y_val, y_scale, \ return_precision, return_scale, &overflow); \ - RETURN_IF_OVERFLOW(context, overflow); \ + RETURN_IF_OVERFLOW(ctx, overflow); \ return DecimalVal(result.value()); \ } \ default: \ @@ -680,18 +690,19 @@ BooleanVal DecimalOperators::CastToBooleanVal( #define DECIMAL_ARITHMETIC_OP_CHECK_NAN(FN_NAME, OP_FN) \ DecimalVal DecimalOperators::FN_NAME( \ - FunctionContext* context, const DecimalVal& x, const DecimalVal& y) { \ + FunctionContext* ctx, const DecimalVal& x, const DecimalVal& y) { \ if (x.is_null || y.is_null) return DecimalVal::null(); \ bool overflow = false; \ bool is_nan = false; \ - int x_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0); \ - int x_scale = 
Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); \ - int y_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 1); \ - int y_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 1); \ + int x_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); \ + int x_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); \ + int y_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1); \ + int y_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 1); \ int return_precision = \ - Expr::GetConstantInt(*context, Expr::RETURN_TYPE_PRECISION); \ - int return_scale = Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SCALE); \ - switch (Expr::GetConstantInt(*context, Expr::RETURN_TYPE_SIZE)) { \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_PRECISION); \ + int return_scale = \ + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SCALE); \ + switch (ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE)) { \ case 4: { \ Decimal4Value x_val = GetDecimal4Value(x, x_size, &overflow); \ Decimal4Value y_val = GetDecimal4Value(y, y_size, &overflow); \ @@ -715,7 +726,7 @@ BooleanVal DecimalOperators::CastToBooleanVal( Decimal16Value y_val = GetDecimal16Value(y, y_size, &overflow); \ Decimal16Value result = x_val.OP_FN<int128_t>(x_scale, y_val, y_scale, \ return_precision, return_scale, &is_nan, &overflow); \ - RETURN_IF_OVERFLOW(context, overflow); \ + RETURN_IF_OVERFLOW(ctx, overflow); \ if (is_nan) return DecimalVal::null(); \ return DecimalVal(result.value()); \ } \ @@ -727,10 +738,10 @@ BooleanVal DecimalOperators::CastToBooleanVal( #define DECIMAL_BINARY_OP_NONNULL(OP_FN, X, Y) \ bool dummy = false; \ - int x_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 0); \ - int x_scale = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 0); \ - int y_size = Expr::GetConstantInt(*context, Expr::ARG_TYPE_SIZE, 1); \ - int y_scale = 
Expr::GetConstantInt(*context, Expr::ARG_TYPE_SCALE, 1); \ + int x_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); \ + int x_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 0); \ + int y_size = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1); \ + int y_scale = ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SCALE, 1); \ int byte_size = ::max(x_size, y_size); \ switch (byte_size) { \ case 4: { \ @@ -759,14 +770,14 @@ BooleanVal DecimalOperators::CastToBooleanVal( #define DECIMAL_BINARY_OP(FN_NAME, OP_FN) \ BooleanVal DecimalOperators::FN_NAME( \ - FunctionContext* context, const DecimalVal& x, const DecimalVal& y) { \ + FunctionContext* ctx, const DecimalVal& x, const DecimalVal& y) { \ if (x.is_null || y.is_null) return BooleanVal::null(); \ DECIMAL_BINARY_OP_NONNULL(OP_FN, x, y) \ } #define NULLSAFE_DECIMAL_BINARY_OP(FN_NAME, OP_FN, IS_EQUAL) \ BooleanVal DecimalOperators::FN_NAME( \ - FunctionContext* context, const DecimalVal& x, const DecimalVal& y) { \ + FunctionContext* ctx, const DecimalVal& x, const DecimalVal& y) { \ if (x.is_null) return BooleanVal(IS_EQUAL ? y.is_null : !y.is_null); \ if (y.is_null) return BooleanVal(!IS_EQUAL); \ DECIMAL_BINARY_OP_NONNULL(OP_FN, x, y) \ http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/decimal-operators.h ---------------------------------------------------------------------- diff --git a/be/src/exprs/decimal-operators.h b/be/src/exprs/decimal-operators.h index 9f425b0..196ffdb 100644 --- a/be/src/exprs/decimal-operators.h +++ b/be/src/exprs/decimal-operators.h @@ -104,7 +104,7 @@ class DecimalOperators { }; /// Evaluates a round from 'val' and returns the result, using the rounding rule of - /// 'type'. + /// 'op'. Returns DecimalVal::null() on overflow.
static DecimalVal RoundDecimal(FunctionContext* context, const DecimalVal& val, int val_precision, int val_scale, int output_precision, int output_scale, const DecimalRoundOp& op); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/expr-codegen-test.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/expr-codegen-test.cc b/be/src/exprs/expr-codegen-test.cc index 0c484f4..c73c89e 100644 --- a/be/src/exprs/expr-codegen-test.cc +++ b/be/src/exprs/expr-codegen-test.cc @@ -16,30 +16,39 @@ // under the License. // The following is cross-compiled to native code and IR, and used in the test below - +#include "exprs/decimal-operators.h" #include "exprs/expr.h" #include "udf/udf.h" using namespace impala; using namespace impala_udf; -// TestGetConstant() fills in the following constants -struct Constants { +// TestGetTypeAttrs() fills in the following constants +struct FnAttr { int return_type_size; int arg0_type_size; int arg1_type_size; int arg2_type_size; }; -IntVal TestGetConstant( +#ifdef IR_COMPILE +#include "exprs/decimal-operators-ir.cc" +#endif + +DecimalVal TestGetFnAttrs( FunctionContext* ctx, const DecimalVal& arg0, StringVal arg1, StringVal arg2) { - Constants* state = reinterpret_cast<Constants*>( + FnAttr* state = reinterpret_cast<FnAttr*>( ctx->GetFunctionState(FunctionContext::THREAD_LOCAL)); - state->return_type_size = Expr::GetConstantInt(*ctx, Expr::RETURN_TYPE_SIZE); - state->arg0_type_size = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_SIZE, 0); - state->arg1_type_size = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_SIZE, 1); - state->arg2_type_size = Expr::GetConstantInt(*ctx, Expr::ARG_TYPE_SIZE, 2); - return IntVal(10); + state->return_type_size = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::RETURN_TYPE_SIZE); + state->arg0_type_size = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 0); + state->arg1_type_size = + 
ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 1); + state->arg2_type_size = + ctx->impl()->GetConstFnAttr(FunctionContextImpl::ARG_TYPE_SIZE, 2); + // This function and its callees call FunctionContextImpl::GetConstFnAttr(); + return DecimalOperators::CastToDecimalVal(ctx, arg0); } // Don't compile the actual test to IR @@ -48,10 +57,12 @@ IntVal TestGetConstant( #include "testutil/gtest-util.h" #include "codegen/llvm-codegen.h" #include "common/init.h" +#include "exprs/anyval-util.h" #include "exprs/expr-context.h" #include "runtime/exec-env.h" #include "runtime/mem-tracker.h" #include "runtime/runtime-state.h" +#include "runtime/test-env.h" #include "service/fe-support.h" #include "udf/udf-internal.h" #include "udf/udf-test-harness.h" @@ -64,28 +75,48 @@ using namespace llvm; namespace impala { -const char* TEST_GET_CONSTANT_SYMBOL = - "_Z15TestGetConstantPN10impala_udf15FunctionContextERKNS_10DecimalValENS_9StringValES5_"; +const char* TEST_GET_FN_ATTR_SYMBOL = + "_Z14TestGetFnAttrsPN10impala_udf15FunctionContextERKNS_10DecimalValENS_9StringValES5_"; const int ARG0_PRECISION = 10; const int ARG0_SCALE = 2; const int ARG1_LEN = 1; +const int RET_PRECISION = 10; +const int RET_SCALE = 1; class ExprCodegenTest : public ::testing::Test { protected: - int InlineConstants(Expr* expr, LlvmCodeGen* codegen, llvm::Function* fn) { - return expr->InlineConstants(codegen, fn); + scoped_ptr<TestEnv> test_env_; + RuntimeState* runtime_state_; + FunctionContext* fn_ctx_; + FnAttr fn_type_attr_; + + int InlineConstFnAttrs(Expr* expr, LlvmCodeGen* codegen, llvm::Function* fn) { + FunctionContext::TypeDesc ret_type = AnyValUtil::ColumnTypeToTypeDesc(expr->type()); + vector<FunctionContext::TypeDesc> arg_types; + for (const Expr* child : expr->children()) { + arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(child->type())); + } + return codegen->InlineConstFnAttrs(ret_type, arg_types, fn); } - static Status CreateFromFile( - ObjectPool* pool, const string& 
filename, scoped_ptr<LlvmCodeGen>* codegen) { - RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(pool, NULL, filename, "test", codegen)); + Status CreateFromFile(const string& filename, scoped_ptr<LlvmCodeGen>* codegen) { + RETURN_IF_ERROR(LlvmCodeGen::CreateFromFile(runtime_state_, + runtime_state_->obj_pool(), NULL, filename, "test", codegen)); return (*codegen)->MaterializeModule(); } virtual void SetUp() { + TQueryOptions query_options; + query_options.__set_decimal_v2(true); + test_env_.reset(new TestEnv()); + EXPECT_OK(test_env_->CreateQueryState(0, 1, 8 * 1024 * 1024, &query_options, + &runtime_state_)); + FunctionContext::TypeDesc return_type; - return_type.type = FunctionContext::TYPE_INT; + return_type.type = FunctionContext::TYPE_DECIMAL; + return_type.precision = RET_PRECISION; + return_type.scale = RET_SCALE; FunctionContext::TypeDesc arg0_type; arg0_type.type = FunctionContext::TYPE_DECIMAL; @@ -104,23 +135,25 @@ class ExprCodegenTest : public ::testing::Test { arg_types.push_back(arg1_type); arg_types.push_back(arg2_type); - fn_ctx_ = UdfTestHarness::CreateTestContext(return_type, arg_types); + fn_ctx_ = UdfTestHarness::CreateTestContext(return_type, arg_types, runtime_state_); // Initialize fn_ctx_ with constants - memset(&constants_, -1, sizeof(Constants)); - fn_ctx_->SetFunctionState(FunctionContext::THREAD_LOCAL, &constants_); + memset(&fn_type_attr_, -1, sizeof(FnAttr)); + fn_ctx_->SetFunctionState(FunctionContext::THREAD_LOCAL, &fn_type_attr_); } virtual void TearDown() { fn_ctx_->impl()->Close(); delete fn_ctx_; + runtime_state_ = NULL; + test_env_.reset(); } - void CheckConstants() { - EXPECT_EQ(constants_.return_type_size, 4); - EXPECT_EQ(constants_.arg0_type_size, 8); - EXPECT_EQ(constants_.arg1_type_size, ARG1_LEN); - EXPECT_EQ(constants_.arg2_type_size, 0); // varlen + void CheckFnAttr() { + EXPECT_EQ(fn_type_attr_.return_type_size, 8); + EXPECT_EQ(fn_type_attr_.arg0_type_size, 8); + EXPECT_EQ(fn_type_attr_.arg1_type_size, ARG1_LEN); + 
EXPECT_EQ(fn_type_attr_.arg2_type_size, 0); // varlen } static bool VerifyFunction(LlvmCodeGen* codegen, llvm::Function* fn) { @@ -130,9 +163,6 @@ class ExprCodegenTest : public ::testing::Test { static void ResetVerification(LlvmCodeGen* codegen) { codegen->ResetVerification(); } - - FunctionContext* fn_ctx_; - Constants constants_; }; TExprNode CreateDecimalLiteral(int precision, int scale) { @@ -183,10 +213,12 @@ TExprNode CreateStringLiteral(int len = -1) { return expr; } -// Creates a function call to TestGetConstant() in test-udfs.h -TExprNode CreateFunctionCall(vector<TExprNode> children) { +// Creates a function call to TestGetFnAttrs() in test-udfs.h +TExprNode CreateFunctionCall(vector<TExprNode> children, int precision, int scale) { TScalarType scalar_type; - scalar_type.type = TPrimitiveType::INT; + scalar_type.type = TPrimitiveType::DECIMAL; + scalar_type.__set_precision(precision); + scalar_type.__set_scale(scale); TTypeNode type; type.type = TTypeNodeType::SCALAR; @@ -196,10 +228,10 @@ TExprNode CreateFunctionCall(vector<TExprNode> children) { col_type.__set_types(vector<TTypeNode>(1, type)); TFunctionName fn_name; - fn_name.function_name = "test_get_constant"; + fn_name.function_name = "test_get_type_attr"; TScalarFunction scalar_fn; - scalar_fn.symbol = TEST_GET_CONSTANT_SYMBOL; + scalar_fn.symbol = TEST_GET_FN_ATTR_SYMBOL; TFunction fn; fn.name = fn_name; @@ -219,18 +251,22 @@ TExprNode CreateFunctionCall(vector<TExprNode> children) { return expr; } -TEST_F(ExprCodegenTest, TestGetConstantInterpreted) { - DecimalVal arg0_val; +TEST_F(ExprCodegenTest, TestGetConstFnAttrsInterpreted) { + // Call fn and check results'. The input is of type Decimal(10,2) (i.e. 10000.25) and + // the output type is Decimal(10,1) (i.e. 10000.3). 
The precision and scale of arguments + // and return types are encoded above (ARG0_*, RET_*); + int64_t v = 1000025; + DecimalVal arg0_val(v); StringVal arg1_val; StringVal arg2_val; - IntVal result = TestGetConstant(fn_ctx_, arg0_val, arg1_val, arg2_val); + DecimalVal result = TestGetFnAttrs(fn_ctx_, arg0_val, arg1_val, arg2_val); // sanity check result EXPECT_EQ(result.is_null, false); - EXPECT_EQ(result.val, 10); - CheckConstants(); + EXPECT_EQ(result.val8, 100003); + CheckFnAttr(); } -TEST_F(ExprCodegenTest, TestInlineConstants) { +TEST_F(ExprCodegenTest, TestInlineConstFnAttrs) { // Setup thrift descriptors TExprNode arg0 = CreateDecimalLiteral(ARG0_PRECISION, ARG0_SCALE); TExprNode arg1 = CreateStringLiteral(ARG1_LEN); @@ -241,7 +277,7 @@ TEST_F(ExprCodegenTest, TestInlineConstants) { exprs.push_back(arg1); exprs.push_back(arg2); - TExprNode fn_call = CreateFunctionCall(exprs); + TExprNode fn_call = CreateFunctionCall(exprs, RET_PRECISION, RET_SCALE); exprs.insert(exprs.begin(), fn_call); TExpr texpr; @@ -253,21 +289,21 @@ TEST_F(ExprCodegenTest, TestInlineConstants) { ExprContext* ctx; ASSERT_OK(Expr::CreateExprTree(&pool, texpr, &ctx)); - // Get TestGetConstant() IR function + // Get TestGetFnAttrs() IR function stringstream test_udf_file; test_udf_file << getenv("IMPALA_HOME") << "/be/build/latest/exprs/expr-codegen-test.ll"; scoped_ptr<LlvmCodeGen> codegen; - ASSERT_OK(ExprCodegenTest::CreateFromFile(&pool, test_udf_file.str(), &codegen)); - Function* fn = codegen->GetFunction(TEST_GET_CONSTANT_SYMBOL, false); + ASSERT_OK(CreateFromFile(test_udf_file.str(), &codegen)); + Function* fn = codegen->GetFunction(TEST_GET_FN_ATTR_SYMBOL, false); ASSERT_TRUE(fn != NULL); - // Function verification should fail because we haven't inlined GetConstant() calls + // Function verification should fail because we haven't inlined GetTypeAttr() calls bool verification_succeeded = VerifyFunction(codegen.get(), fn); EXPECT_FALSE(verification_succeeded); - // Call 
InlineConstants() and rerun verification - int replaced = InlineConstants(ctx->root(), codegen.get(), fn); - EXPECT_EQ(replaced, 4); + // Call InlineConstFnAttrs() and rerun verification + int replaced = InlineConstFnAttrs(ctx->root(), codegen.get(), fn); + EXPECT_EQ(replaced, 9); ResetVerification(codegen.get()); verification_succeeded = VerifyFunction(codegen.get(), fn); EXPECT_TRUE(verification_succeeded) << LlvmCodeGen::Print(fn); @@ -277,17 +313,19 @@ TEST_F(ExprCodegenTest, TestInlineConstants) { ASSERT_TRUE(fn != NULL); void* fn_ptr; codegen->AddFunctionToJit(fn, &fn_ptr); - ASSERT_OK(codegen->FinalizeModule()); - LOG(ERROR) << "Optimized fn: " << LlvmCodeGen::Print(fn); - - // Call fn and check results - DecimalVal arg0_val; - typedef IntVal (*TestGetConstantType)(FunctionContext*, const DecimalVal&); - IntVal result = reinterpret_cast<TestGetConstantType>(fn_ptr)(fn_ctx_, arg0_val); + EXPECT_TRUE(codegen->FinalizeModule().ok()) << LlvmCodeGen::Print(fn); + + // Call fn and check results'. The input is of type Decimal(10,2) (i.e. 10000.25) and + // the output type is Decimal(10,1) (i.e. 10000.3). 
The precision and scale of arguments + // and return types are encoded above (ARG0_*, RET_*); + int64_t v = 1000025; + DecimalVal arg0_val(v); + typedef DecimalVal (*TestGetFnAttrs)(FunctionContext*, const DecimalVal&); + DecimalVal result = reinterpret_cast<TestGetFnAttrs>(fn_ptr)(fn_ctx_, arg0_val); // sanity check result EXPECT_EQ(result.is_null, false); - EXPECT_EQ(result.val, 10); - CheckConstants(); + EXPECT_EQ(result.val8, 100003); + CheckFnAttr(); } } http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/expr-test.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/expr-test.cc b/be/src/exprs/expr-test.cc index cad716e..1b2dc51 100644 --- a/be/src/exprs/expr-test.cc +++ b/be/src/exprs/expr-test.cc @@ -1335,6 +1335,16 @@ DecimalTestCase decimal_cases[] = { {{ false, 1, 32, 1 }, { false, 1, 32, 1 }} }, { "mod(cast('-1.23' as decimal(32,2)), cast('1.0' as decimal(32,2)))", {{ false, -23, 32, 2 }, { false, -23, 32, 2 }} }, + { "cast(cast(0.12344 as decimal(6,5)) as decimal(6,4))", + {{ false, 1234, 6, 4 }, { false, 1234, 6, 4 }} }, + { "cast(cast(0.12345 as decimal(6,5)) as decimal(6,4))", + {{ false, 1234, 6, 4 }, { false, 1235, 6, 4 }} }, + { "cast(cast('0.999' as decimal(4,3)) as decimal(1,0))", + {{ false, 0, 1, 0 }, { false, 1, 1, 0 }} }, + { "cast(cast(999999999.99 as DECIMAL(11,2)) as DECIMAL(9,0))", + {{ false, 999999999, 9, 0 }, { true, 0, 9, 0 }} }, + { "cast(cast(-999999999.99 as DECIMAL(11,2)) as DECIMAL(9,0))", + {{ false, -999999999, 9, 0 }, { true, 0, 9, 0 }} }, { "mod(cast(NULL as decimal(2,0)), cast('10' as decimal(2,0)))", {{ true, 0, 2, 0 }, { true, 0, 2, 0 }} }, { "mod(cast('10' as decimal(2,0)), cast(NULL as decimal(2,0)))", @@ -1357,7 +1367,7 @@ TEST_F(ExprTest, DecimalArithmeticExprs) { string opt = "DECIMAL_V2=" + lexical_cast<string>(v2); executor_->pushExecOption(opt); for (const DecimalTestCase& c : decimal_cases) { - const DecimalExpectedResult& r = 
c.expected[0]; + const DecimalExpectedResult& r = c.expected[v2]; const ColumnType& type = ColumnType::CreateDecimalType(r.precision, r.scale); if (r.null) { TestIsNull(c.expr, type); http://git-wip-us.apache.org/repos/asf/incubator-impala/blob/f982c3f7/be/src/exprs/expr.cc ---------------------------------------------------------------------- diff --git a/be/src/exprs/expr.cc b/be/src/exprs/expr.cc index f119a9d..32c16a0 100644 --- a/be/src/exprs/expr.cc +++ b/be/src/exprs/expr.cc @@ -79,8 +79,6 @@ namespace impala { const char* Expr::LLVM_CLASS_NAME = "class.impala::Expr"; -const char* Expr::GET_CONSTANT_INT_SYMBOL_PREFIX = "_ZN6impala4Expr14GetConstantInt"; - template<class T> bool ParseString(const string& str, T* val) { istringstream stream(str); @@ -589,90 +587,6 @@ Status Expr::GetConstVal( return GetFnContextError(context); } -int Expr::GetConstantInt(const FunctionContext::TypeDesc& return_type, - const std::vector<FunctionContext::TypeDesc>& arg_types, ExprConstant c, int i) { - switch (c) { - case RETURN_TYPE_SIZE: - DCHECK_EQ(i, -1); - return AnyValUtil::TypeDescToColumnType(return_type).GetByteSize(); - case RETURN_TYPE_PRECISION: - DCHECK_EQ(i, -1); - DCHECK_EQ(return_type.type, FunctionContext::TYPE_DECIMAL); - return return_type.precision; - case RETURN_TYPE_SCALE: - DCHECK_EQ(i, -1); - DCHECK_EQ(return_type.type, FunctionContext::TYPE_DECIMAL); - return return_type.scale; - case ARG_TYPE_SIZE: - DCHECK_GE(i, 0); - DCHECK_LT(i, arg_types.size()); - return AnyValUtil::TypeDescToColumnType(arg_types[i]).GetByteSize(); - case ARG_TYPE_PRECISION: - DCHECK_GE(i, 0); - DCHECK_LT(i, arg_types.size()); - DCHECK_EQ(arg_types[i].type, FunctionContext::TYPE_DECIMAL); - return arg_types[i].precision; - case ARG_TYPE_SCALE: - DCHECK_GE(i, 0); - DCHECK_LT(i, arg_types.size()); - DCHECK_EQ(arg_types[i].type, FunctionContext::TYPE_DECIMAL); - return arg_types[i].scale; - default: - CHECK(false) << "NYI"; - return -1; - } -} - -int Expr::GetConstantInt(const 
FunctionContext& ctx, ExprConstant c, int i) { - return GetConstantInt(ctx.GetReturnType(), ctx.impl()->arg_types(), c, i); -} - -int Expr::InlineConstants(LlvmCodeGen* codegen, Function* fn) { - FunctionContext::TypeDesc return_type = AnyValUtil::ColumnTypeToTypeDesc(type_); - vector<FunctionContext::TypeDesc> arg_types; - for (int i = 0; i < children_.size(); ++i) { - arg_types.push_back(AnyValUtil::ColumnTypeToTypeDesc(children_[i]->type_)); - } - return InlineConstants(return_type, arg_types, codegen, fn); -} - -int Expr::InlineConstants(const FunctionContext::TypeDesc& return_type, - const std::vector<FunctionContext::TypeDesc>& arg_types, LlvmCodeGen* codegen, - Function* fn) { - int replaced = 0; - for (inst_iterator iter = inst_begin(fn), end = inst_end(fn); iter != end;) { - // Increment iter now so we don't mess it up modifying the instruction below - Instruction* instr = &*(iter++); - - // Look for call instructions - if (!isa<CallInst>(instr)) continue; - CallInst* call_instr = cast<CallInst>(instr); - Function* called_fn = call_instr->getCalledFunction(); - - // Look for call to Expr::GetConstant*() - if (called_fn == NULL || - called_fn->getName().find(GET_CONSTANT_INT_SYMBOL_PREFIX) == string::npos) { - continue; - } - - // 'c' and 'i' arguments must be constant - ConstantInt* c_arg = dyn_cast<ConstantInt>(call_instr->getArgOperand(1)); - ConstantInt* i_arg = dyn_cast<ConstantInt>(call_instr->getArgOperand(2)); - DCHECK(c_arg != NULL) << "Non-constant 'c' argument to Expr::GetConstant*()"; - DCHECK(i_arg != NULL) << "Non-constant 'i' argument to Expr::GetConstant*()"; - - // Replace the called function with the appropriate constant - ExprConstant c_val = static_cast<ExprConstant>(c_arg->getSExtValue()); - int i_val = static_cast<int>(i_arg->getSExtValue()); - // All supported constants are currently integers. 
- call_instr->replaceAllUsesWith(ConstantInt::get(codegen->GetType(TYPE_INT), - GetConstantInt(return_type, arg_types, c_val, i_val))); - call_instr->eraseFromParent(); - ++replaced; - } - return replaced; -} - Status Expr::GetCodegendComputeFnWrapper(LlvmCodeGen* codegen, Function** fn) { if (ir_compute_fn_ != NULL) { *fn = ir_compute_fn_;
