This is an automated email from the ASF dual-hosted git repository.
ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 31aa19d ARROW-4206: [Gandiva] support decimal divide and mod
31aa19d is described below
commit 31aa19da25352d5f6abdf3264d57306c3b52bb05
Author: Pindikura Ravindra <[email protected]>
AuthorDate: Thu Mar 14 11:29:42 2019 +0530
ARROW-4206: [Gandiva] support decimal divide and mod
Author: Pindikura Ravindra <[email protected]>
Closes #3813 from pravindra/gdivmod and squashes the following commits:
96ef4054 <Pindikura Ravindra> ARROW-4206: add global symbol for new fns
a9ad13fe <Pindikura Ravindra> ARROW-4206: Add more tests/comments
697c2343 <Pindikura Ravindra> ARROW-4206: Fix build errors
267f117e <Pindikura Ravindra> ARROW-4206: support decimal divide and mod
---
cpp/src/arrow/util/basic_decimal.h | 3 +
cpp/src/arrow/util/decimal-test.cc | 115 +++++++++-
cpp/src/arrow/util/decimal.h | 5 +
cpp/src/gandiva/decimal_ir.cc | 58 +++++
cpp/src/gandiva/decimal_ir.h | 4 +
cpp/src/gandiva/decimal_xlarge.cc | 121 ++++++++--
cpp/src/gandiva/decimal_xlarge.h | 8 +
cpp/src/gandiva/function_registry_arithmetic.cc | 2 +
cpp/src/gandiva/precompiled/CMakeLists.txt | 7 +-
cpp/src/gandiva/precompiled/decimal_ops.cc | 84 +++++++
cpp/src/gandiva/precompiled/decimal_ops.h | 10 +
cpp/src/gandiva/precompiled/decimal_ops_test.cc | 281 ++++++++++++++++++++----
cpp/src/gandiva/precompiled/decimal_wrapper.cc | 34 +++
cpp/src/gandiva/tests/decimal_single_test.cc | 33 ++-
14 files changed, 698 insertions(+), 67 deletions(-)
diff --git a/cpp/src/arrow/util/basic_decimal.h
b/cpp/src/arrow/util/basic_decimal.h
index 7929b11..2e5857c 100644
--- a/cpp/src/arrow/util/basic_decimal.h
+++ b/cpp/src/arrow/util/basic_decimal.h
@@ -138,6 +138,9 @@ class ARROW_EXPORT BasicDecimal128 {
/// - If 'round' is false, the right-most digits are simply dropped.
BasicDecimal128 ReduceScaleBy(int32_t reduce_by, bool round = true) const;
+ // returns 1 for positive and zero decimal values, -1 for negative decimal
values.
+ inline int64_t Sign() const { return 1 | (high_bits_ >> 63); }
+
/// \brief count the number of leading binary zeroes.
int32_t CountLeadingBinaryZeros() const;
diff --git a/cpp/src/arrow/util/decimal-test.cc
b/cpp/src/arrow/util/decimal-test.cc
index db4d35f..4ba7d7f 100644
--- a/cpp/src/arrow/util/decimal-test.cc
+++ b/cpp/src/arrow/util/decimal-test.cc
@@ -23,12 +23,16 @@
#include <tuple>
#include <gtest/gtest.h>
+#include <boost/multiprecision/cpp_int.hpp>
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
#include "arrow/util/decimal.h"
#include "arrow/util/macros.h"
+using boost::multiprecision::int128_t;
+
namespace arrow {
class DecimalTestFixture : public ::testing::Test {
@@ -466,20 +470,113 @@ TEST(Decimal128Test, TestToInteger) {
ASSERT_RAISES(Invalid, invalid_int64.ToInteger(&out2));
}
+template <typename ArrowType, typename CType = typename ArrowType::c_type>
+std::vector<CType> GetRandomNumbers(int32_t size) {
+ auto rand = random::RandomArrayGenerator(0x5487655);
+ auto x_array = rand.Numeric<ArrowType>(size, 0,
std::numeric_limits<CType>::max(), 0);
+
+ auto x_ptr = x_array->data()->template GetValues<CType>(1);
+ std::vector<CType> ret;
+ for (int i = 0; i < size; ++i) {
+ ret.push_back(x_ptr[i]);
+ }
+ return ret;
+}
+
TEST(Decimal128Test, Multiply) {
- Decimal128 result;
+ ASSERT_EQ(Decimal128(60501), Decimal128(301) * Decimal128(201));
+
+ ASSERT_EQ(Decimal128(-60501), Decimal128(-301) * Decimal128(201));
+
+ ASSERT_EQ(Decimal128(-60501), Decimal128(301) * Decimal128(-201));
+
+ ASSERT_EQ(Decimal128(60501), Decimal128(-301) * Decimal128(-201));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ Decimal128 result = Decimal128(x) * Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) * y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX,
INT64_MAX}) {
+ for (auto y :
+ std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 0, 1, 2, 32,
INT32_MAX}) {
+ Decimal128 result = Decimal128(x.str()) * Decimal128(y.str());
+ ASSERT_EQ(Decimal128((x * y).str()), result) << " x: " << x << " y: " <<
y;
+ }
+ }
+}
+
+TEST(Decimal128Test, Divide) {
+ ASSERT_EQ(Decimal128(66), Decimal128(20100) / Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-66), Decimal128(-20100) / Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-66), Decimal128(20100) / Decimal128(-301));
- result = Decimal128("301") * Decimal128("201");
- ASSERT_EQ(result.ToIntegerString(), "60501");
+ ASSERT_EQ(Decimal128(66), Decimal128(-20100) / Decimal128(-301));
- result = Decimal128("-301") * Decimal128("201");
- ASSERT_EQ(result.ToIntegerString(), "-60501");
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
- result = Decimal128("301") * Decimal128("-201");
- ASSERT_EQ(result.ToIntegerString(), "-60501");
+ Decimal128 result = Decimal128(x) / Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) / y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX,
INT64_MAX}) {
+ for (auto y : std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 1, 2, 32,
INT32_MAX}) {
+ Decimal128 result = Decimal128(x.str()) * Decimal128(y.str());
+ ASSERT_EQ(Decimal128((x * y).str()), result) << " x: " << x << " y: " <<
y;
+ }
+ }
+}
+
+TEST(Decimal128Test, Mod) {
+ ASSERT_EQ(Decimal128(234), Decimal128(20100) % Decimal128(301));
+
+ ASSERT_EQ(Decimal128(-234), Decimal128(-20100) % Decimal128(301));
+
+ ASSERT_EQ(Decimal128(234), Decimal128(20100) % Decimal128(-301));
+
+ ASSERT_EQ(Decimal128(-234), Decimal128(-20100) % Decimal128(-301));
+
+ // Test some random numbers.
+ for (auto x : GetRandomNumbers<Int32Type>(16)) {
+ for (auto y : GetRandomNumbers<Int32Type>(16)) {
+ if (y == 0) {
+ continue;
+ }
+
+ Decimal128 result = Decimal128(x) % Decimal128(y);
+ ASSERT_EQ(Decimal128(static_cast<int64_t>(x) % y), result)
+ << " x: " << x << " y: " << y;
+ }
+ }
+
+ // Test some edge cases
+ for (auto x : std::vector<int128_t>{-INT64_MAX, -INT32_MAX, 0, INT32_MAX,
INT64_MAX}) {
+ for (auto y : std::vector<int128_t>{-INT32_MAX, -32, -2, -1, 1, 2, 32,
INT32_MAX}) {
+ Decimal128 result = Decimal128(x.str()) * Decimal128(y.str());
+ ASSERT_EQ(Decimal128((x * y).str()), result) << " x: " << x << " y: " <<
y;
+ }
+ }
+}
- result = Decimal128("-301") * Decimal128("-201");
- ASSERT_EQ(result.ToIntegerString(), "60501");
+TEST(Decimal128Test, Sign) {
+ ASSERT_EQ(1, Decimal128(999999).Sign());
+ ASSERT_EQ(-1, Decimal128(-999999).Sign());
+ ASSERT_EQ(1, Decimal128(0).Sign());
}
TEST(Decimal128Test, GetWholeAndFraction) {
diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h
index 4c61a17..3cb86d1 100644
--- a/cpp/src/arrow/util/decimal.h
+++ b/cpp/src/arrow/util/decimal.h
@@ -123,6 +123,11 @@ class ARROW_EXPORT Decimal128 : public BasicDecimal128 {
return Status::OK();
}
+ friend std::ostream& operator<<(std::ostream& os, const Decimal128& decimal)
{
+ os << decimal.ToIntegerString();
+ return os;
+ }
+
private:
/// Converts internal error code to Status
Status ToArrowStatus(DecimalStatus dstatus) const;
diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc
index 53727bb..47e60cf 100644
--- a/cpp/src/gandiva/decimal_ir.cc
+++ b/cpp/src/gandiva/decimal_ir.cc
@@ -406,6 +406,60 @@ Status DecimalIR::BuildMultiply() {
return Status::OK();
}
+Status DecimalIR::BuildDivideOrMod(const std::string& function_name,
+ const std::string& internal_fname) {
+ // Create fn prototype :
+ // int128_t
+ // divide_decimal128_decimal128(int64_t execution_context,
+ // int128_t x_value, int32_t x_precision,
int32_t x_scale,
+ // int128_t y_value, int32_t y_precision,
int32_t y_scale
+ // int32_t out_precision, int32_t out_scale)
+ auto i32 = types()->i32_type();
+ auto i128 = types()->i128_type();
+ auto function = BuildFunction(function_name, i128,
+ {
+ {"execution_context", types()->i64_type()},
+ {"x_value", i128},
+ {"x_precision", i32},
+ {"x_scale", i32},
+ {"y_value", i128},
+ {"y_precision", i32},
+ {"y_scale", i32},
+ {"out_precision", i32},
+ {"out_scale", i32},
+ });
+
+ auto arg_iter = function->arg_begin();
+ auto execution_context = &arg_iter[0];
+ ValueFull x(&arg_iter[1], &arg_iter[2], &arg_iter[3]);
+ ValueFull y(&arg_iter[4], &arg_iter[5], &arg_iter[6]);
+ ValueFull out(nullptr, &arg_iter[7], &arg_iter[8]);
+
+ auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
+ ir_builder()->SetInsertPoint(entry);
+
+ // Make call to pre-compiled IR function.
+ auto block = ir_builder()->GetInsertBlock();
+ auto out_high_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_hi",
block);
+ auto out_low_ptr = new llvm::AllocaInst(types()->i64_type(), 0, "out_low",
block);
+ auto x_split = ValueSplit::MakeFromInt128(this, x.value());
+ auto y_split = ValueSplit::MakeFromInt128(this, y.value());
+
+ std::vector<llvm::Value*> args = {
+ execution_context, x_split.high(), x_split.low(), x.precision(),
x.scale(),
+ y_split.high(), y_split.low(), y.precision(), y.scale(),
out.precision(),
+ out.scale(), out_high_ptr, out_low_ptr,
+ };
+ ir_builder()->CreateCall(module()->getFunction(internal_fname), args);
+
+ auto out_high = ir_builder()->CreateLoad(out_high_ptr);
+ auto out_low = ir_builder()->CreateLoad(out_low_ptr);
+ auto result = ValueSplit(out_high, out_low).AsInt128(this);
+
+ ir_builder()->CreateRet(result);
+ return Status::OK();
+}
+
Status DecimalIR::AddFunctions(Engine* engine) {
auto decimal_ir = std::make_shared<DecimalIR>(engine);
@@ -418,6 +472,10 @@ Status DecimalIR::AddFunctions(Engine* engine) {
ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd());
ARROW_RETURN_NOT_OK(decimal_ir->BuildSubtract());
ARROW_RETURN_NOT_OK(decimal_ir->BuildMultiply());
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildDivideOrMod(
+ "divide_decimal128_decimal128",
"divide_internal_decimal128_decimal128"));
+ ARROW_RETURN_NOT_OK(decimal_ir->BuildDivideOrMod("mod_decimal128_decimal128",
+
"mod_internal_decimal128_decimal128"));
return Status::OK();
}
diff --git a/cpp/src/gandiva/decimal_ir.h b/cpp/src/gandiva/decimal_ir.h
index e552cf1..048b9d3 100644
--- a/cpp/src/gandiva/decimal_ir.h
+++ b/cpp/src/gandiva/decimal_ir.h
@@ -149,6 +149,10 @@ class DecimalIR : public FunctionIRBuilder {
// Build the function for decimal multiplication.
Status BuildMultiply();
+ // Build the function for decimal division/mod.
+ Status BuildDivideOrMod(const std::string& function_name,
+ const std::string& internal_name);
+
// Add a trace in IR code.
void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);
diff --git a/cpp/src/gandiva/decimal_xlarge.cc
b/cpp/src/gandiva/decimal_xlarge.cc
index 570cd88..4a8f3e5 100644
--- a/cpp/src/gandiva/decimal_xlarge.cc
+++ b/cpp/src/gandiva/decimal_xlarge.cc
@@ -55,7 +55,35 @@ void ExportedDecimalFunctions::AddMappings(Engine* engine)
const {
engine->AddGlobalMappingForFunc(
"gdv_xlarge_multiply_and_scale_down", types->void_type()
/*return_type*/, args,
reinterpret_cast<void*>(gdv_xlarge_multiply_and_scale_down));
+
+ // gdv_xlarge_scale_up_and_divide
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t y_low
+ types->i32_type(), // int32_t increase_scale_by
+ types->i64_ptr_type(), // int64_t* out_high
+ types->i64_ptr_type(), // uint64_t* out_low
+ types->i8_ptr_type()}; // bool* overflow
+
+ engine->AddGlobalMappingForFunc(
+ "gdv_xlarge_scale_up_and_divide", types->void_type() /*return_type*/,
args,
+ reinterpret_cast<void*>(gdv_xlarge_scale_up_and_divide));
+
+ // gdv_xlarge_mod
+ args = {types->i64_type(), // int64_t x_high
+ types->i64_type(), // uint64_t x_low
+ types->i32_type(), // int32_t x_scale
+ types->i64_type(), // int64_t y_high
+ types->i64_type(), // uint64_t y_low
+ types->i32_type(), // int32_t y_scale
+ types->i64_ptr_type(), // int64_t* out_high
+ types->i64_ptr_type()}; // uint64_t* out_low
+
+ engine->AddGlobalMappingForFunc("gdv_xlarge_mod", types->void_type()
/*return_type*/,
+ args,
reinterpret_cast<void*>(gdv_xlarge_mod));
}
+
} // namespace gandiva
#endif // !GANDIVA_UNIT_TEST
@@ -103,27 +131,34 @@ static BasicDecimal128 ConvertToDecimal128(int256_t in,
bool* overflow) {
return is_negative ? -result : result;
}
+static constexpr int32_t kMaxLargeScale = 2 * DecimalTypeUtil::kMaxPrecision;
+
+// Compute the scale multipliers once.
+static std::array<int256_t, kMaxLargeScale + 1> kLargeScaleMultipliers =
+ ([]() -> std::array<int256_t, kMaxLargeScale + 1> {
+ std::array<int256_t, kMaxLargeScale + 1> values;
+ values[0] = 1;
+ for (int32_t idx = 1; idx <= kMaxLargeScale; idx++) {
+ values[idx] = values[idx - 1] * 10;
+ }
+ return values;
+ })();
+
+static int256_t GetScaleMultiplier(int scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, kMaxLargeScale);
+
+ return kLargeScaleMultipliers[scale];
+}
+
// divide input by 10^reduce_by, and round up the fractional part.
static int256_t ReduceScaleBy(int256_t in, int32_t reduce_by) {
- DCHECK_GE(reduce_by, 0);
- DCHECK_LE(reduce_by, 2 * DecimalTypeUtil::kMaxPrecision);
-
if (reduce_by == 0) {
// nothing to do.
return in;
}
- int256_t divisor;
- if (reduce_by <= DecimalTypeUtil::kMaxPrecision) {
- divisor = ConvertToInt256(BasicDecimal128::GetScaleMultiplier(reduce_by));
- } else {
- divisor = ConvertToInt256(
- BasicDecimal128::GetScaleMultiplier(DecimalTypeUtil::kMaxPrecision));
- for (auto i = DecimalTypeUtil::kMaxPrecision; i < reduce_by; i++) {
- divisor *= 10;
- }
- }
-
+ int256_t divisor = GetScaleMultiplier(reduce_by);
DCHECK_GT(divisor, 0);
DCHECK_EQ(divisor % 2, 0); // multiple of 10.
auto result = in / divisor;
@@ -135,6 +170,14 @@ static int256_t ReduceScaleBy(int256_t in, int32_t
reduce_by) {
return result;
}
+// multiply input by 10^increase_by.
+static int256_t IncreaseScaleBy(int256_t in, int32_t increase_by) {
+ DCHECK_GE(increase_by, 0);
+ DCHECK_LE(increase_by, 2 * DecimalTypeUtil::kMaxPrecision);
+
+ return in * GetScaleMultiplier(increase_by);
+}
+
} // namespace internal
} // namespace gandiva
@@ -155,4 +198,54 @@ void gdv_xlarge_multiply_and_scale_down(int64_t x_high,
uint64_t x_low, int64_t
*out_low = result.low_bits();
}
+void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t
y_high,
+ uint64_t y_low, int32_t increase_scale_by,
+ int64_t* out_high, uint64_t* out_low,
+ bool* overflow) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+
+ int256_t x_large = gandiva::internal::ConvertToInt256(x);
+ int256_t x_large_scaled_up =
+ gandiva::internal::IncreaseScaleBy(x_large, increase_scale_by);
+ int256_t y_large = gandiva::internal::ConvertToInt256(y);
+ int256_t result_large = x_large_scaled_up / y_large;
+ int256_t remainder_large = x_large_scaled_up % y_large;
+
+ // Since we are scaling up and then, scaling down, round-up the result (+1
for +ve,
+ // -1 for -ve), if the remainder is >= 2 * divisor.
+ if (abs(2 * remainder_large) >= abs(y_large)) {
+ // x +ve and y +ve, result is +ve => (1 ^ 1) + 1 = 0 + 1 = +1
+ // x +ve and y -ve, result is -ve => (-1 ^ 1) + 1 = -2 + 1 = -1
+ // x +ve and y -ve, result is -ve => (1 ^ -1) + 1 = -2 + 1 = -1
+ // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1
+ result_large += (x.Sign() ^ y.Sign()) + 1;
+ }
+ auto result = gandiva::internal::ConvertToDecimal128(result_large, overflow);
+ *out_high = result.high_bits();
+ *out_low = result.low_bits();
+}
+
+void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t
y_high,
+ uint64_t y_low, int32_t y_scale, int64_t* out_high,
+ uint64_t* out_low) {
+ BasicDecimal128 x{x_high, x_low};
+ BasicDecimal128 y{y_high, y_low};
+
+ int256_t x_large = gandiva::internal::ConvertToInt256(x);
+ int256_t y_large = gandiva::internal::ConvertToInt256(y);
+ if (x_scale < y_scale) {
+ x_large = gandiva::internal::IncreaseScaleBy(x_large, y_scale - x_scale);
+ } else {
+ y_large = gandiva::internal::IncreaseScaleBy(y_large, x_scale - y_scale);
+ }
+ auto intermediate_result = x_large % y_large;
+ bool overflow = false;
+ auto result = gandiva::internal::ConvertToDecimal128(intermediate_result,
&overflow);
+ DCHECK_EQ(overflow, false);
+
+ *out_high = result.high_bits();
+ *out_low = result.low_bits();
+}
+
} // extern "C"
diff --git a/cpp/src/gandiva/decimal_xlarge.h b/cpp/src/gandiva/decimal_xlarge.h
index 9d48937..c2e2dd8 100644
--- a/cpp/src/gandiva/decimal_xlarge.h
+++ b/cpp/src/gandiva/decimal_xlarge.h
@@ -27,4 +27,12 @@ void gdv_xlarge_multiply_and_scale_down(int64_t x_high,
uint64_t x_low, int64_t
uint64_t y_low, int32_t
reduce_scale_by,
int64_t* out_high, uint64_t* out_low,
bool* overflow);
+
+void gdv_xlarge_scale_up_and_divide(int64_t x_high, uint64_t x_low, int64_t
y_high,
+ uint64_t y_low, int32_t increase_scale_by,
+ int64_t* out_high, uint64_t* out_low,
bool* overflow);
+
+void gdv_xlarge_mod(int64_t x_high, uint64_t x_low, int32_t x_scale, int64_t
y_high,
+ uint64_t y_low, int32_t y_scale, int64_t* out_high,
+ uint64_t* out_low);
}
diff --git a/cpp/src/gandiva/function_registry_arithmetic.cc
b/cpp/src/gandiva/function_registry_arithmetic.cc
index 921f91c..ad8445b 100644
--- a/cpp/src/gandiva/function_registry_arithmetic.cc
+++ b/cpp/src/gandiva/function_registry_arithmetic.cc
@@ -57,6 +57,8 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(multiply, decimal128),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, decimal128),
+ BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, decimal128),
BINARY_RELATIONAL_BOOL_FN(equal),
BINARY_RELATIONAL_BOOL_FN(not_equal),
diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt
b/cpp/src/gandiva/precompiled/CMakeLists.txt
index 3ad0e09..b2c3017 100644
--- a/cpp/src/gandiva/precompiled/CMakeLists.txt
+++ b/cpp/src/gandiva/precompiled/CMakeLists.txt
@@ -128,6 +128,9 @@ if(ARROW_BUILD_TESTS)
add_precompiled_unit_test(arithmetic_ops_test.cc arithmetic_ops.cc
../context_helper.cc)
add_precompiled_unit_test(extended_math_ops_test.cc extended_math_ops.cc
../context_helper.cc)
- add_precompiled_unit_test(decimal_ops_test.cc decimal_ops.cc
../decimal_type_util.cc
- ../decimal_xlarge.cc)
+ add_precompiled_unit_test(decimal_ops_test.cc
+ decimal_ops.cc
+ ../decimal_type_util.cc
+ ../decimal_xlarge.cc
+ ../context_helper.cc)
endif()
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc
b/cpp/src/gandiva/precompiled/decimal_ops.cc
index 9aa1f41..e13a5d8 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops.cc
@@ -23,8 +23,14 @@
#include "gandiva/decimal_type_util.h"
#include "gandiva/decimal_xlarge.h"
+#include "gandiva/gdv_function_stubs.h"
#include "gandiva/logging.h"
+// Several operations (multiply, divide, mod, ..) require converting to
256-bit, and we
+// use the boost library for doing 256-bit operations. To avoid references to
boost from
+// the precompiled-to-ir code (this causes issues with symbol resolution at
runtime), we
+// use a wrapper exported from the CPP code. The wrapper functions are named
gdv_xlarge_xx
+
namespace gandiva {
namespace decimalops {
@@ -339,5 +345,83 @@ BasicDecimal128 Multiply(const BasicDecimalScalar128& x,
const BasicDecimalScala
return result;
}
+BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow) {
+ if (y.value() == 0) {
+ char const* err_msg = "divide by zero error";
+ gdv_fn_context_set_error_msg(context, err_msg);
+ return 0;
+ }
+
+ // scale upto the output scale, and do an integer division.
+ int32_t delta_scale = out_scale + y.scale() - x.scale();
+ DCHECK_GE(delta_scale, 0);
+
+ BasicDecimal128 result;
+ auto num_bits_required_after_scaling = MaxBitsRequiredAfterScaling(x,
delta_scale);
+ if (num_bits_required_after_scaling <= 127) {
+ // fast-path. The dividend fits in 128-bit after scaling too.
+ *overflow = false;
+
+ // do the division.
+ auto x_scaled = CheckAndIncreaseScale(x.value(), delta_scale);
+ BasicDecimal128 remainder;
+ auto status = x_scaled.Divide(y.value(), &result, &remainder);
+ DCHECK_EQ(status, arrow::DecimalStatus::kSuccess);
+
+ // round-up
+ if (BasicDecimal128::Abs(2 * remainder) >=
BasicDecimal128::Abs(y.value())) {
+ result += (x.value().Sign() ^ y.value().Sign()) + 1;
+ }
+ } else {
+ // convert to 256-bit and do the divide.
+ *overflow = delta_scale > 38 && num_bits_required_after_scaling > 255;
+ if (!*overflow) {
+ int64_t result_high;
+ uint64_t result_low;
+
+ gdv_xlarge_scale_up_and_divide(x.value().high_bits(),
x.value().low_bits(),
+ y.value().high_bits(),
y.value().low_bits(),
+ delta_scale, &result_high, &result_low,
overflow);
+ result = BasicDecimal128(result_high, result_low);
+ }
+ }
+ return result;
+}
+
+BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t out_precision,
+ int32_t out_scale, bool* overflow) {
+ if (y.value() == 0) {
+ char const* err_msg = "divide by zero error";
+ gdv_fn_context_set_error_msg(context, err_msg);
+ return 0;
+ }
+
+ // Adsjust x and y to the same scale (higher one), and then, do a integer
mod.
+ *overflow = false;
+ BasicDecimal128 result;
+ int32_t min_lz = MinLeadingZeros(x, y);
+ if (min_lz >= 2) {
+ auto higher_scale = std::max(x.scale(), y.scale());
+ auto x_scaled = CheckAndIncreaseScale(x.value(), higher_scale - x.scale());
+ auto y_scaled = CheckAndIncreaseScale(y.value(), higher_scale - y.scale());
+ result = x_scaled % y_scaled;
+ DCHECK_LE(BasicDecimal128::Abs(result), BasicDecimal128::GetMaxValue());
+ } else {
+ int64_t result_high;
+ uint64_t result_low;
+
+ gdv_xlarge_mod(x.value().high_bits(), x.value().low_bits(), x.scale(),
+ y.value().high_bits(), y.value().low_bits(), y.scale(),
&result_high,
+ &result_low);
+ result = BasicDecimal128(result_high, result_low);
+ }
+ DCHECK(BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(x.value()) ||
+ BasicDecimal128::Abs(result) <= BasicDecimal128::Abs(y.value()));
+ return result;
+}
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h
b/cpp/src/gandiva/precompiled/decimal_ops.h
index f45bc78..e0aea7e 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops.h
+++ b/cpp/src/gandiva/precompiled/decimal_ops.h
@@ -40,5 +40,15 @@ arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128&
x,
const BasicDecimalScalar128& y, int32_t
out_precision,
int32_t out_scale, bool* overflow);
+/// Divide 'x' by 'y', and return the result.
+arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t
out_precision,
+ int32_t out_scale, bool* overflow);
+
+/// Divide 'x' by 'y', and return the remainder.
+arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
+ const BasicDecimalScalar128& y, int32_t
out_precision,
+ int32_t out_scale, bool* overflow);
+
} // namespace decimalops
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
index f6d0b02..9672a25 100644
--- a/cpp/src/gandiva/precompiled/decimal_ops_test.cc
+++ b/cpp/src/gandiva/precompiled/decimal_ops_test.cc
@@ -22,17 +22,27 @@
#include "arrow/testing/gtest_util.h"
#include "gandiva/decimal_scalar.h"
#include "gandiva/decimal_type_util.h"
+#include "gandiva/execution_context.h"
#include "gandiva/precompiled/decimal_ops.h"
#include "gandiva/precompiled/types.h"
namespace gandiva {
+const arrow::Decimal128 kThirtyFive9s(std::string(35, '9'));
+const arrow::Decimal128 kThirtySix9s(std::string(36, '9'));
+const arrow::Decimal128 kThirtyEight9s(std::string(38, '9'));
+
class TestDecimalSql : public ::testing::Test {
protected:
static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
const DecimalScalar128& y, const DecimalScalar128&
expected_result,
bool expected_overflow);
+ static void VerifyAllSign(DecimalTypeUtil::Op op, const DecimalScalar128&
left,
+ const DecimalScalar128& right,
+ const DecimalScalar128& expected_output,
+ bool expected_overflow);
+
void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected_result) {
// TODO: overflow checks
@@ -53,7 +63,34 @@ class TestDecimalSql : public ::testing::Test {
void MultiplyAndVerifyAllSign(const DecimalScalar128& x, const
DecimalScalar128& y,
const DecimalScalar128& expected_result,
- bool expected_overflow);
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpMultiply, x, y, expected_result,
+ expected_overflow);
+ }
+
+ void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result, bool
expected_overflow) {
+ return Verify(DecimalTypeUtil::kOpDivide, x, y, expected_result,
expected_overflow);
+ }
+
+ void DivideAndVerifyAllSign(const DecimalScalar128& x, const
DecimalScalar128& y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpDivide, x, y, expected_result,
+ expected_overflow);
+ }
+
+ void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected_result, bool
expected_overflow) {
+ return Verify(DecimalTypeUtil::kOpMod, x, y, expected_result,
expected_overflow);
+ }
+
+ void ModAndVerifyAllSign(const DecimalScalar128& x, const DecimalScalar128&
y,
+ const DecimalScalar128& expected_result,
+ bool expected_overflow) {
+ return VerifyAllSign(DecimalTypeUtil::kOpMod, x, y, expected_result,
+ expected_overflow);
+ }
};
#define EXPECT_DECIMAL_EQ(op, x, y, expected_result, expected_overflow,
actual_result, \
@@ -78,6 +115,7 @@ void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const
DecimalScalar128& x,
auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
bool overflow = false;
+ int64_t context = 0;
Decimal128TypePtr out_type;
EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));
@@ -101,6 +139,18 @@ void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const
DecimalScalar128& x,
decimalops::Multiply(x, y, out_type->precision(), out_type->scale(),
&overflow);
break;
+ case DecimalTypeUtil::kOpDivide:
+ op_name = "divide";
+ out_value = decimalops::Divide(context, x, y, out_type->precision(),
+ out_type->scale(), &overflow);
+ break;
+
+ case DecimalTypeUtil::kOpMod:
+ op_name = "mod";
+ out_value = decimalops::Mod(context, x, y, out_type->precision(),
out_type->scale(),
+ &overflow);
+ break;
+
default:
// not implemented.
ASSERT_FALSE(true);
@@ -110,6 +160,33 @@ void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const
DecimalScalar128& x,
overflow);
}
+void TestDecimalSql::VerifyAllSign(DecimalTypeUtil::Op op, const
DecimalScalar128& left,
+ const DecimalScalar128& right,
+ const DecimalScalar128& expected_output,
+ bool expected_overflow) {
+ // both +ve
+ Verify(op, left, right, expected_output, expected_overflow);
+
+ // left -ve
+ Verify(op, -left, right, -expected_output, expected_overflow);
+
+ if (op == DecimalTypeUtil::kOpMod) {
+ // right -ve
+ Verify(op, left, -right, expected_output, expected_overflow);
+
+ // both -ve
+ Verify(op, -left, -right, -expected_output, expected_overflow);
+ } else {
+ DCHECK(op == DecimalTypeUtil::kOpMultiply || op ==
DecimalTypeUtil::kOpDivide);
+
+ // right -ve
+ Verify(op, left, -right, -expected_output, expected_overflow);
+
+ // both -ve
+ Verify(op, -left, -right, expected_output, expected_overflow);
+ }
+}
+
TEST_F(TestDecimalSql, Add) {
// fast-path
AddAndVerify(DecimalScalar128{"201", 30, 3}, // x
@@ -156,28 +233,7 @@ TEST_F(TestDecimalSql, Subtract) {
DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6});
}
-void TestDecimalSql::MultiplyAndVerifyAllSign(const DecimalScalar128& left,
- const DecimalScalar128& right,
- const DecimalScalar128&
expected_output,
- bool expected_overflow) {
- // both +ve
- MultiplyAndVerify(left, right, expected_output, expected_overflow);
-
- // left -ve
- MultiplyAndVerify(-left, right, -expected_output, expected_overflow);
-
- // right -ve
- MultiplyAndVerify(left, -right, -expected_output, expected_overflow);
-
- // both -ve
- MultiplyAndVerify(-left, -right, expected_output, expected_overflow);
-}
-
TEST_F(TestDecimalSql, Multiply) {
- const std::string thirty_five_9s(35, '9');
- const std::string thirty_six_9s(36, '9');
- const std::string thirty_eight_9s(38, '9');
-
// fast-path : out_precision < 38
MultiplyAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
DecimalScalar128{"301", 10, 2}, // y
@@ -207,21 +263,21 @@ TEST_F(TestDecimalSql, Multiply) {
// get trimmed).
MultiplyAndVerifyAllSign(
DecimalScalar128{"201", 20, 3}, // x
- DecimalScalar128{thirty_five_9s, 35, 2}, // y
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
DecimalScalar128{"20099999999999999999999999999999999799", 38, 5}, //
expected
false); //
overflow
// out_precision == 38, very large values, no trimming of scale (scale <= 6
doesn't
// get trimmed). overflow expected.
- MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
- DecimalScalar128{thirty_six_9s, 35, 2}, // y
- DecimalScalar128{"0", 38, 5}, // expected
- true); // overflow
-
- MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
- DecimalScalar128{thirty_eight_9s, 35, 2}, // y
- DecimalScalar128{"0", 38, 5}, //
expected
- true); //
overflow
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtySix9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 5}, // expected
+ true); // overflow
+
+ MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyEight9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 5}, //
expected
+ true); //
overflow
// out_precision == 38, small input values, trimming of scale.
MultiplyAndVerifyAllSign(DecimalScalar128{"201", 20, 5}, // x
@@ -232,23 +288,23 @@ TEST_F(TestDecimalSql, Multiply) {
// out_precision == 38, large values, trimming of scale.
MultiplyAndVerifyAllSign(
DecimalScalar128{"201", 20, 5}, // x
- DecimalScalar128{thirty_five_9s, 35, 5}, // y
+ DecimalScalar128{kThirtyFive9s, 35, 5}, // y
DecimalScalar128{"2010000000000000000000000000000000", 38, 6}, //
expected
false); //
overflow
// out_precision == 38, very large values, trimming of scale (requires
convert to 256).
MultiplyAndVerifyAllSign(
- DecimalScalar128{thirty_five_9s, 38, 20}, // x
- DecimalScalar128{thirty_six_9s, 38, 20}, // y
+ DecimalScalar128{kThirtyFive9s, 38, 20}, // x
+ DecimalScalar128{kThirtySix9s, 38, 20}, // y
DecimalScalar128{"9999999999999999999999999999999999890", 38, 6}, //
expected
false); //
overflow
// out_precision == 38, very large values, trimming of scale (requires
convert to 256).
// should cause overflow.
- MultiplyAndVerifyAllSign(DecimalScalar128{thirty_five_9s, 38, 4}, // x
- DecimalScalar128{thirty_six_9s, 38, 4}, // y
- DecimalScalar128{"0", 38, 6}, //
expected
- true); //
overflow
+ MultiplyAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 4}, // x
+ DecimalScalar128{kThirtySix9s, 38, 4}, // y
+ DecimalScalar128{"0", 38, 6}, // expected
+ true); // overflow
// corner cases.
MultiplyAndVerifyAllSign(
@@ -274,10 +330,153 @@ TEST_F(TestDecimalSql, Multiply) {
false); //
overflow
MultiplyAndVerifyAllSign(
- DecimalScalar128{thirty_five_9s, 38, 38}, // x
- DecimalScalar128{thirty_six_9s, 38, 38}, // y
+ DecimalScalar128{kThirtyFive9s, 38, 38}, // x
+ DecimalScalar128{kThirtySix9s, 38, 38}, // y
DecimalScalar128{"100000000000000000000000000000000", 38, 37}, //
expected
false); //
overflow
}
+TEST_F(TestDecimalSql, Divide) {
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
+ DecimalScalar128{"301", 10, 2}, // y
+ DecimalScalar128{"6677740863787", 23, 14}, //
expected
+ false); //
overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{"301", 20, 2}, // y
+ DecimalScalar128{"667774086378737542", 38, 19}, //
expected
+ false); //
overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
+ DecimalScalar128{"0", 38, 19}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(
+ DecimalScalar128{kThirtyFive9s, 35, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"497512437810945273631840796019900493", 38, 6}, //
expected
+ false); //
overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 20}, // y
+ DecimalScalar128{"1000000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x
+ DecimalScalar128{"10000", 20, 0}, // y
+ DecimalScalar128{"3193912806356148", 38, 8}, //
expected
+ false);
+
+ // Corner cases
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // y
+ DecimalScalar128{"1000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, INT64_MAX, 38, 4}, // y
+ DecimalScalar128{"2000000", 38, 6}, // expected
+ false); // overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 19, 5}, //
x
+ DecimalScalar128{0, INT64_MAX, 19, 5}, //
y
+ DecimalScalar128{"20000000000000000001", 38, 19}, //
expected
+ false); //
overflow
+
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 38, 37}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 38}, // y
+ DecimalScalar128{"10000000", 38, 6}, // expected
+ false); // overflow
+
+ // overflow
+ DivideAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"0", 38, 6}, // expected
+ true);
+}
+
+TEST_F(TestDecimalSql, Mod) {
+ ModAndVerifyAllSign(DecimalScalar128{"201", 10, 3}, // x
+ DecimalScalar128{"301", 10, 2}, // y
+ DecimalScalar128{"201", 10, 3}, // expected
+ false); // overflow
+
+ ModAndVerify(DecimalScalar128{"201", 20, 2}, // x
+ DecimalScalar128{"301", 20, 3}, // y
+ DecimalScalar128{"204", 20, 3}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{"201", 20, 3}, // x
+ DecimalScalar128{kThirtyFive9s, 35, 2}, // y
+ DecimalScalar128{"201", 20, 3}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{kThirtyFive9s, 35, 6}, // x
+ DecimalScalar128{"201", 20, 3}, // y
+ DecimalScalar128{"180999", 23, 6}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{kThirtyEight9s, 38, 20}, // x
+ DecimalScalar128{kThirtyFive9s, 38, 21}, // y
+ DecimalScalar128{"9990", 38, 21}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{"31939128063561476055", 38, 8}, // x
+ DecimalScalar128{"10000", 20, 0}, // y
+ DecimalScalar128{"63561476055", 28, 8}, //
expected
+ false);
+
+ ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, UINT64_MAX, 38, 4}, // y
+ DecimalScalar128{"0", 38, 4}, // expected
+ false); // overflow
+
+ ModAndVerifyAllSign(DecimalScalar128{0, UINT64_MAX, 38, 4}, // x
+ DecimalScalar128{0, INT64_MAX, 38, 4}, // y
+ DecimalScalar128{"1", 38, 4}, // expected
+ false); // overflow
+}
+
+TEST_F(TestDecimalSql, DivideByZero) {
+ gandiva::ExecutionContext context;
+ int32_t result_precision;
+ int32_t result_scale;
+ bool overflow;
+
+ // divide-by-zero should cause an error.
+ context.Reset();
+ result_precision = 38;
+ result_scale = 19;
+ decimalops::Divide(reinterpret_cast<int64>(&context),
DecimalScalar128{"201", 20, 3},
+ DecimalScalar128{"0", 20, 2}, result_precision,
result_scale,
+ &overflow);
+ EXPECT_TRUE(context.has_error());
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ // divide-by-nonzero should not cause an error.
+ context.Reset();
+ decimalops::Divide(reinterpret_cast<int64>(&context),
DecimalScalar128{"201", 20, 3},
+ DecimalScalar128{"1", 20, 2}, result_precision,
result_scale,
+ &overflow);
+ EXPECT_FALSE(context.has_error());
+
+ // mod-by-zero should cause an error.
+ context.Reset();
+ result_precision = 20;
+ result_scale = 3;
+ decimalops::Mod(reinterpret_cast<int64>(&context), DecimalScalar128{"201",
20, 3},
+ DecimalScalar128{"0", 20, 2}, result_precision, result_scale,
+ &overflow);
+ EXPECT_TRUE(context.has_error());
+ EXPECT_EQ(context.get_error(), "divide by zero error");
+
+ // mod-by-nonzero should not cause an error.
+ context.Reset();
+ decimalops::Mod(reinterpret_cast<int64>(&context), DecimalScalar128{"201",
20, 3},
+ DecimalScalar128{"1", 20, 2}, result_precision, result_scale,
+ &overflow);
+ EXPECT_FALSE(context.has_error());
+}
+
} // namespace gandiva
diff --git a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
index 1066c5c..d5c919e 100644
--- a/cpp/src/gandiva/precompiled/decimal_wrapper.cc
+++ b/cpp/src/gandiva/precompiled/decimal_wrapper.cc
@@ -52,4 +52,38 @@ void multiply_internal_decimal128_decimal128(int64_t x_high,
uint64_t x_low,
*out_low = out.low_bits();
}
+FORCE_INLINE
+void divide_internal_decimal128_decimal128(
+ int64_t context, int64_t x_high, uint64_t x_low, int32_t x_precision,
int32_t x_scale,
+ int64_t y_high, uint64_t y_low, int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t out_scale, int64_t* out_high, uint64_t*
out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+ bool overflow;
+
+ // TODO ravindra: generate error on overflows (ARROW-4570).
+ arrow::BasicDecimal128 out =
+ gandiva::decimalops::Divide(context, x, y, out_precision, out_scale,
&overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
+FORCE_INLINE
+void mod_internal_decimal128_decimal128(int64_t context, int64_t x_high,
uint64_t x_low,
+ int32_t x_precision, int32_t x_scale,
+ int64_t y_high, uint64_t y_low,
+ int32_t y_precision, int32_t y_scale,
+ int32_t out_precision, int32_t
out_scale,
+ int64_t* out_high, uint64_t* out_low) {
+ gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
+ gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
+ bool overflow;
+
+ // TODO ravindra: generate error on overflows (ARROW-4570).
+ arrow::BasicDecimal128 out =
+ gandiva::decimalops::Mod(context, x, y, out_precision, out_scale,
&overflow);
+ *out_high = out.high_bits();
+ *out_low = out.low_bits();
+}
+
} // extern "C"
diff --git a/cpp/src/gandiva/tests/decimal_single_test.cc
b/cpp/src/gandiva/tests/decimal_single_test.cc
index 0cc93e7..c28c47e 100644
--- a/cpp/src/gandiva/tests/decimal_single_test.cc
+++ b/cpp/src/gandiva/tests/decimal_single_test.cc
@@ -66,6 +66,16 @@ class TestDecimalOps : public ::testing::Test {
Verify(DecimalTypeUtil::kOpMultiply, "multiply", x, y, expected);
}
+ void DivideAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpDivide, "divide", x, y, expected);
+ }
+
+ void ModAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
+ const DecimalScalar128& expected) {
+ Verify(DecimalTypeUtil::kOpMod, "mod", x, y, expected);
+ }
+
protected:
arrow::MemoryPool* pool_;
};
@@ -258,7 +268,8 @@ TEST_F(TestDecimalOps, TestSubtract) {
decimal_literal("-3211", 32, 3)); // expected
}
-// Lots of unit tests for multiply in decimal_ops_test.cc. So, keeping this
basic.
+// Lots of unit tests for multiply/divide/mod in decimal_ops_test.cc. So,
keeping these
+// basic.
TEST_F(TestDecimalOps, TestMultiply) {
// fast-path
MultiplyAndVerify(decimal_literal("201", 10, 3), // x
@@ -271,4 +282,24 @@ TEST_F(TestDecimalOps, TestMultiply) {
DecimalScalar128("9999999999999999999999999999999999890",
38, 6));
}
+TEST_F(TestDecimalOps, TestDivide) {
+ DivideAndVerify(decimal_literal("201", 10, 3), // x
+ decimal_literal("301", 10, 2), // y
+ decimal_literal("6677740863787", 23, 14)); // expected
+
+ DivideAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x
+ DecimalScalar128(std::string(35, '9'), 38, 20), // x
+ DecimalScalar128("1000000000", 38, 6));
+}
+
+TEST_F(TestDecimalOps, TestMod) {
+ ModAndVerify(decimal_literal("201", 20, 2), // x
+ decimal_literal("301", 20, 3), // y
+ decimal_literal("204", 20, 3)); // expected
+
+ ModAndVerify(DecimalScalar128(std::string(38, '9'), 38, 20), // x
+ DecimalScalar128(std::string(35, '9'), 38, 21), // x
+ DecimalScalar128("9990", 38, 21));
+}
+
} // namespace gandiva