This is an automated email from the ASF dual-hosted git repository.
emkornfield pushed a commit to branch decimal256
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/decimal256 by this push:
new e6dc833 ARROW-10102: [C++] Refactor BasicDecimal128 Multiplication to
use unsigned helper
e6dc833 is described below
commit e6dc83343d47c0c7c4ecc0a547359652defd69e2
Author: Ezra <[email protected]>
AuthorDate: Thu Oct 1 09:00:28 2020 -0700
ARROW-10102: [C++] Refactor BasicDecimal128 Multiplication to use unsigned
helper
Closes #8279 from Luminarys/master
Authored-by: Ezra <[email protected]>
Signed-off-by: Micah Kornfield <[email protected]>
---
cpp/src/arrow/util/basic_decimal.cc | 110 ++++++++++++++++++++++++++----------
cpp/src/arrow/util/decimal_test.cc | 8 +++
2 files changed, 88 insertions(+), 30 deletions(-)
diff --git a/cpp/src/arrow/util/basic_decimal.cc
b/cpp/src/arrow/util/basic_decimal.cc
index 3e7daa3..ac85bd0 100644
--- a/cpp/src/arrow/util/basic_decimal.cc
+++ b/cpp/src/arrow/util/basic_decimal.cc
@@ -28,6 +28,7 @@
#include <string>
#include "arrow/util/bit_util.h"
+#include "arrow/util/int128_internal.h"
#include "arrow/util/int_util_internal.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
@@ -119,8 +120,11 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = {
BasicDecimal128(271050543121376108LL, 9257742014424809472ULL),
BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)};
+#ifdef ARROW_USE_NATIVE_INT128
+static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF;
+#else
static constexpr uint64_t kIntMask = 0xFFFFFFFF;
-static constexpr auto kCarryBit = static_cast<uint64_t>(1) <<
static_cast<uint64_t>(32);
+#endif
// same as ScaleMultipliers[38] - 1
static constexpr BasicDecimal128 kMaxValue =
@@ -248,40 +252,86 @@ BasicDecimal128& BasicDecimal128::operator>>=(uint32_t
bits) {
return *this;
}
-BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
- // Break the left and right numbers into 32 bit chunks
- // so that we can multiply them without overflow.
- const uint64_t L0 = static_cast<uint64_t>(high_bits_) >> 32;
- const uint64_t L1 = static_cast<uint64_t>(high_bits_) & kIntMask;
- const uint64_t L2 = low_bits_ >> 32;
- const uint64_t L3 = low_bits_ & kIntMask;
-
- const uint64_t R0 = static_cast<uint64_t>(right.high_bits_) >> 32;
- const uint64_t R1 = static_cast<uint64_t>(right.high_bits_) & kIntMask;
- const uint64_t R2 = right.low_bits_ >> 32;
- const uint64_t R3 = right.low_bits_ & kIntMask;
+namespace {
- uint64_t product = L3 * R3;
- low_bits_ = product & kIntMask;
-
- uint64_t sum = product >> 32;
-
- product = L2 * R3;
- sum += product;
- high_bits_ = static_cast<int64_t>(sum < product ? kCarryBit : 0);
+// TODO: Remove this guard once it's used by BasicDecimal256
+#ifndef ARROW_USE_NATIVE_INT128
+// This method losslessly multiplies x and y into a 128 bit unsigned integer
+// whose high bits will be stored in hi and low bits in lo.
+void ExtendAndMultiplyUint64(uint64_t x, uint64_t y, uint64_t* hi, uint64_t*
lo) {
+#ifdef ARROW_USE_NATIVE_INT128
+ const __uint128_t r = static_cast<__uint128_t>(x) * y;
+ *lo = r & kInt64Mask;
+ *hi = r >> 64;
+#else
+ // If no native 128 bit integer type is available, perform multiplication
+ // by splitting up x and y into 32 bit high/low bit components,
+ // allowing us to represent the multiplication as
+ // x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32
+ // + x_hi * y_hi * 2^64.
+ //
+ // Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi.
+ // Therefore,
+ // lo_lo is (x_lo * y_lo)_lo,
+ // lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo,
+ // hi_lo is ((x_hi * y_hi)_lo + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi)_lo,
+ // hi_hi is (x_hi * y_hi)_hi
+ const uint64_t x_lo = x & kIntMask;
+ const uint64_t y_lo = y & kIntMask;
+ const uint64_t x_hi = x >> 32;
+ const uint64_t y_hi = y >> 32;
+
+ const uint64_t t = x_lo * y_lo;
+ const uint64_t t_lo = t & kIntMask;
+ const uint64_t t_hi = t >> 32;
+
+ const uint64_t u = x_hi * y_lo + t_hi;
+ const uint64_t u_lo = u & kIntMask;
+ const uint64_t u_hi = u >> 32;
+
+ const uint64_t v = x_lo * y_hi + u_lo;
+ const uint64_t v_hi = v >> 32;
+
+ *hi = x_hi * y_hi + u_hi + v_hi;
+ *lo = (v << 32) | t_lo;
+#endif
+}
+#endif
- product = L3 * R2;
- sum += product;
+void MultiplyUint128(uint64_t x_hi, uint64_t x_lo, uint64_t y_hi, uint64_t
y_lo,
+ uint64_t* hi, uint64_t* lo) {
+#ifdef ARROW_USE_NATIVE_INT128
+ const __uint128_t x = (static_cast<__uint128_t>(x_hi) << 64) | x_lo;
+ const __uint128_t y = (static_cast<__uint128_t>(y_hi) << 64) | y_lo;
+ const __uint128_t r = x * y;
+ *lo = r & kInt64Mask;
+ *hi = r >> 64;
+#else
+ // To perform 128 bit multiplication without a native 128 bit integer
+ // type, we first perform lossless 64 bit multiplication of the low
+ // bits, and then add x_hi * y_lo and x_lo * y_hi to the high
+ // bits. Note that we can skip adding x_hi * y_hi because it is
+ // scaled by 2^128 and so never reaches the low 128 bits.
+ ExtendAndMultiplyUint64(x_lo, y_lo, hi, lo);
+ *hi += (x_hi * y_lo) + (x_lo * y_hi);
+#endif
+}
- low_bits_ += sum << 32;
+} // namespace
- if (sum < product) {
- high_bits_ += kCarryBit;
+BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
+ // Since the max value of BasicDecimal128 is supposed to be 1e38 - 1 and the
+ // min is its negation, taking the absolute values here should always be safe.
+ const bool negate = Sign() != right.Sign();
+ BasicDecimal128 x = BasicDecimal128::Abs(*this);
+ BasicDecimal128 y = BasicDecimal128::Abs(right);
+ uint64_t hi;
+ MultiplyUint128(x.high_bits(), x.low_bits(), y.high_bits(), y.low_bits(),
&hi,
+ &low_bits_);
+ high_bits_ = hi;
+ if (negate) {
+ Negate();
}
-
- high_bits_ += static_cast<int64_t>(sum >> 32);
- high_bits_ += L1 * R3 + L2 * R2 + L3 * R1;
- high_bits_ += (L0 * R3 + L1 * R2 + L2 * R1 + L3 * R0) << 32;
return *this;
}
diff --git a/cpp/src/arrow/util/decimal_test.cc
b/cpp/src/arrow/util/decimal_test.cc
index abd4c2a..372dbee 100644
--- a/cpp/src/arrow/util/decimal_test.cc
+++ b/cpp/src/arrow/util/decimal_test.cc
@@ -922,6 +922,14 @@ TEST(Decimal128Test, Multiply) {
Decimal128 result = Decimal128(x) * Decimal128(y);
ASSERT_EQ(Decimal128(static_cast<int64_t>(x) * y), result)
<< " x: " << x << " y: " << y;
+ // Test by multiplying with an additional 32 bit factor, then additional
+ // factor of 2^30 to test results in the range of -2^123 to 2^123
without overflow.
+ for (auto z : GetRandomNumbers<Int32Type>(32)) {
+ int128_t w = static_cast<int128_t>(x) * y * (1ull << 30);
+ Decimal128 expected = Decimal128FromInt128(static_cast<int128_t>(w) *
z);
+ Decimal128 actual = Decimal128FromInt128(w) * Decimal128(z);
+ ASSERT_EQ(expected, actual) << " w: " << x << " * " << y << " * 2^30
z: " << z;
+ }
}
}