This is an automated email from the ASF dual-hosted git repository.

emkornfield pushed a commit to branch decimal256
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/decimal256 by this push:
     new e6dc833  ARROW-10102: [C++] Refactor BasicDecimal128 Multiplication to use unsigned helper
e6dc833 is described below

commit e6dc83343d47c0c7c4ecc0a547359652defd69e2
Author: Ezra <[email protected]>
AuthorDate: Thu Oct 1 09:00:28 2020 -0700

    ARROW-10102: [C++] Refactor BasicDecimal128 Multiplication to use unsigned helper
    
    Closes #8279 from Luminarys/master
    
    Authored-by: Ezra <[email protected]>
    Signed-off-by: Micah Kornfield <[email protected]>
---
 cpp/src/arrow/util/basic_decimal.cc | 110 ++++++++++++++++++++++++++----------
 cpp/src/arrow/util/decimal_test.cc  |   8 +++
 2 files changed, 88 insertions(+), 30 deletions(-)

diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc
index 3e7daa3..ac85bd0 100644
--- a/cpp/src/arrow/util/basic_decimal.cc
+++ b/cpp/src/arrow/util/basic_decimal.cc
@@ -28,6 +28,7 @@
 #include <string>
 
 #include "arrow/util/bit_util.h"
+#include "arrow/util/int128_internal.h"
 #include "arrow/util/int_util_internal.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
@@ -119,8 +120,11 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = {
     BasicDecimal128(271050543121376108LL, 9257742014424809472ULL),
     BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)};
 
+#ifdef ARROW_USE_NATIVE_INT128  // native path: truncate __uint128_t products to 64 bits
+static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF;  // used as `r & kInt64Mask`
+#else  // portable path: split 64-bit words into 32-bit halves
 static constexpr uint64_t kIntMask = 0xFFFFFFFF;
-static constexpr auto kCarryBit = static_cast<uint64_t>(1) << 
static_cast<uint64_t>(32);
+#endif  // ARROW_USE_NATIVE_INT128
 
 // same as ScaleMultipliers[38] - 1
 static constexpr BasicDecimal128 kMaxValue =
@@ -248,40 +252,86 @@ BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) {
   return *this;
 }
 
-BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
-  // Break the left and right numbers into 32 bit chunks
-  // so that we can multiply them without overflow.
-  const uint64_t L0 = static_cast<uint64_t>(high_bits_) >> 32;
-  const uint64_t L1 = static_cast<uint64_t>(high_bits_) & kIntMask;
-  const uint64_t L2 = low_bits_ >> 32;
-  const uint64_t L3 = low_bits_ & kIntMask;
-
-  const uint64_t R0 = static_cast<uint64_t>(right.high_bits_) >> 32;
-  const uint64_t R1 = static_cast<uint64_t>(right.high_bits_) & kIntMask;
-  const uint64_t R2 = right.low_bits_ >> 32;
-  const uint64_t R3 = right.low_bits_ & kIntMask;
+namespace {
 
-  uint64_t product = L3 * R3;
-  low_bits_ = product & kIntMask;
-
-  uint64_t sum = product >> 32;
-
-  product = L2 * R3;
-  sum += product;
-  high_bits_ = static_cast<int64_t>(sum < product ? kCarryBit : 0);
+// TODO: Remove this guard once it's used by BasicDecimal256
+#ifndef ARROW_USE_NATIVE_INT128  // only the portable MultiplyUint128 path calls this
+// This function losslessly multiplies x and y into a 128 bit unsigned integer
+// whose high bits will be stored in hi and low bits in lo.
+void ExtendAndMultiplyUint64(uint64_t x, uint64_t y, uint64_t* hi, uint64_t* 
lo) {
+#ifdef ARROW_USE_NATIVE_INT128  // unreachable while the enclosing #ifndef guard exists
+  const __uint128_t r = static_cast<__uint128_t>(x) * y;
+  *lo = r & kInt64Mask;
+  *hi = r >> 64;
+#else
+  // Without native 128-bit integers, perform the multiplication
+  // by splitting up x and y into 32 bit high/low bit components,
+  // allowing us to represent the multiplication as
+  // x * y = x_lo * y_lo + x_hi * y_lo * 2^32 + y_hi * x_lo * 2^32
+  // + x_hi * y_hi * 2^64.
+  //
+  // Now, consider the final output as lo_lo || lo_hi || hi_lo || hi_hi,
+  // listed from least to most significant 32-bit limb. Therefore,
+  // lo_lo is (x_lo * y_lo)_lo,
+  // lo_hi is ((x_lo * y_lo)_hi + (x_hi * y_lo)_lo + (x_lo * y_hi)_lo)_lo, and
+  // hi_lo || hi_hi is x_hi * y_hi + (x_hi * y_lo)_hi + (x_lo * y_hi)_hi plus the
+  // carry out of lo_hi; t, u and v below accumulate these without overflow.
+  const uint64_t x_lo = x & kIntMask;
+  const uint64_t y_lo = y & kIntMask;
+  const uint64_t x_hi = x >> 32;
+  const uint64_t y_hi = y >> 32;
+
+  const uint64_t t = x_lo * y_lo;
+  const uint64_t t_lo = t & kIntMask;
+  const uint64_t t_hi = t >> 32;
+
+  const uint64_t u = x_hi * y_lo + t_hi;  // <= 2^64 - 2^32, cannot overflow
+  const uint64_t u_lo = u & kIntMask;
+  const uint64_t u_hi = u >> 32;
+
+  const uint64_t v = x_lo * y_hi + u_lo;  // same bound as u
+  const uint64_t v_hi = v >> 32;
+
+  *hi = x_hi * y_hi + u_hi + v_hi;  // high word of a 64x64 product, always fits
+  *lo = (v << 32) | t_lo;
+#endif
+}
+#endif
 
-  product = L3 * R2;
-  sum += product;
+void MultiplyUint128(uint64_t x_hi, uint64_t x_lo, uint64_t y_hi, uint64_t 
y_lo,
+                     uint64_t* hi, uint64_t* lo) {
+#ifdef ARROW_USE_NATIVE_INT128
+  const __uint128_t x = (static_cast<__uint128_t>(x_hi) << 64) | x_lo;
+  const __uint128_t y = (static_cast<__uint128_t>(y_hi) << 64) | y_lo;
+  const __uint128_t r = x * y;
+  *lo = r & kInt64Mask;
+  *hi = r >> 64;
+#else
+  // Without native 128-bit integers, compute the product modulo 2^128:
+  // multiply the low words losslessly, then add the low 64 bits of the
+  // cross terms x_hi * y_lo and x_lo * y_hi to the high word. x_hi * y_hi
+  // is scaled by 2^128 (as are the cross terms' upper halves), so it lies
+  // entirely outside the 128-bit result and can be skipped.
+  ExtendAndMultiplyUint64(x_lo, y_lo, hi, lo);
+  *hi += (x_hi * y_lo) + (x_lo * y_hi);  // wraps mod 2^64 by design
+#endif
+}
 
-  low_bits_ += sum << 32;
+}  // namespace
 
-  if (sum < product) {
-    high_bits_ += kCarryBit;
+BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
+  // Multiply magnitudes as unsigned 128-bit values, then restore the sign.
+  // Values are expected within +/-(1e38 - 1) (kMaxValue), so Abs() is safe.
+  const bool negate = Sign() != right.Sign();
+  BasicDecimal128 x = BasicDecimal128::Abs(*this);
+  BasicDecimal128 y = BasicDecimal128::Abs(right);
+  uint64_t hi;  // high_bits_ is signed, so stage the unsigned high word here
+  MultiplyUint128(x.high_bits(), x.low_bits(), y.high_bits(), y.low_bits(), 
&hi,
+                  &low_bits_);
+  high_bits_ = hi;  // unsigned -> signed: two's-complement wrap (impl-defined pre-C++20)
+  if (negate) {
+    Negate();
   }
-
-  high_bits_ += static_cast<int64_t>(sum >> 32);
-  high_bits_ += L1 * R3 + L2 * R2 + L3 * R1;
-  high_bits_ += (L0 * R3 + L1 * R2 + L2 * R1 + L3 * R0) << 32;
   return *this;
 }
 
diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc
index abd4c2a..372dbee 100644
--- a/cpp/src/arrow/util/decimal_test.cc
+++ b/cpp/src/arrow/util/decimal_test.cc
@@ -922,6 +922,14 @@ TEST(Decimal128Test, Multiply) {
       Decimal128 result = Decimal128(x) * Decimal128(y);
       ASSERT_EQ(Decimal128(static_cast<int64_t>(x) * y), result)
           << " x: " << x << " y: " << y;
+      // Also scale x * y by 2^30 and then by a random 32-bit factor z, so the
+      // products reach magnitudes up to 2^123, still well inside int128 range.
+      for (auto z : GetRandomNumbers<Int32Type>(32)) {
+        int128_t w = static_cast<int128_t>(x) * y * (1ull << 30);  // exact in int128
+        Decimal128 expected = Decimal128FromInt128(static_cast<int128_t>(w) * 
z);  // reference: native int128 product
+        Decimal128 actual = Decimal128FromInt128(w) * Decimal128(z);  // under test
+        ASSERT_EQ(expected, actual) << " w: " << x << " * " << y << " * 2^30 
z: " << z;
+      }
     }
   }
 

Reply via email to