bkietz commented on a change in pull request #10364:
URL: https://github.com/apache/arrow/pull/10364#discussion_r641132241



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
##########
@@ -516,6 +520,141 @@ std::shared_ptr<ScalarFunction> 
MakeUnarySignedArithmeticFunctionNotNull(
   return func;
 }
 
+// Decimal arithmetics
+struct BinaryDecimal : public KernelState {
+  const std::shared_ptr<DecimalType> left_type, right_type;
+  std::shared_ptr<DataType> out_type;
+
+  explicit BinaryDecimal(const KernelInitArgs& args)
+      : left_type(checked_pointer_cast<DecimalType>(args.inputs[0].type)),
+        right_type(checked_pointer_cast<DecimalType>(args.inputs[1].type)) {
+    DCHECK_EQ(left_type->id(), right_type->id());
+  }
+
+  // create instance of derived class T
+  template <typename T>
+  static Result<std::unique_ptr<KernelState>> Make(const KernelInitArgs& args) 
{
+    auto op = ::arrow::internal::make_unique<T>(args);
+    if (op->left_type->scale() < 0 || op->right_type->scale() < 0) {
+      return Status::Invalid("Decimals with negative scales not supported");
+    }
+    RETURN_NOT_OK(op->Init(op->left_type->precision(), op->left_type->scale(),
+                           op->right_type->precision(), 
op->right_type->scale()));
+    return std::move(op);
+  }
+
+  // return error and stop kernel execution if output precision is out of bound
+  Status Init(int32_t out_prec, int32_t out_scale) {
+    if (left_type->id() == Type::DECIMAL128) {
+      ARROW_ASSIGN_OR_RAISE(out_type, Decimal128Type::Make(out_prec, 
out_scale));
+    } else {
+      ARROW_ASSIGN_OR_RAISE(out_type, Decimal256Type::Make(out_prec, 
out_scale));
+    }
+    return Status::OK();
+  }
+
+  Result<std::shared_ptr<DataType>> ResolveOutput(const 
std::vector<ValueDescr>&) const {
+    return out_type;
+  }
+};
+
+template <bool IsSubtract>
+struct AddOrSubtractDecimal : public BinaryDecimal {
+  using BinaryDecimal::BinaryDecimal;
+
+  int32_t left_scaleup, right_scaleup;
+
+  // called by kernel::init()
+  static Result<std::unique_ptr<KernelState>> Make(KernelContext*,
+                                                   const KernelInitArgs& args) 
{
+    return BinaryDecimal::Make<AddOrSubtractDecimal<IsSubtract>>(args);
+  }
+
+  // figure out output type and arg scaling, called by Make()
+  Status Init(int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+    const int32_t out_scale = std::max(s1, s2);
+    const int32_t out_prec = std::max(p1 - s1, p2 - s2) + 1 + out_scale;
+    left_scaleup = out_scale - s1;
+    right_scaleup = out_scale - s2;
+    return BinaryDecimal::Init(out_prec, out_scale);
+  }
+
+  // called by kernel::exec() for each value pair
+  // TODO(yibo): avoid repeat rescaling of scalar arg
+  template <typename T, typename Arg0, typename Arg1>
+  T Call(KernelContext*, Arg0 left, Arg1 right, Status*) const {
+    if (left_scaleup > 0) left = left.IncreaseScaleBy(left_scaleup);
+    if (right_scaleup > 0) right = right.IncreaseScaleBy(right_scaleup);
+    if (IsSubtract) right = -right;
+    return left + right;
+  }
+};
+
+using AddDecimal = AddOrSubtractDecimal</*IsSubtract=*/false>;
+using SubtractDecimal = AddOrSubtractDecimal</*IsSubtract=*/true>;
+
+struct MultiplyDecimal : public BinaryDecimal {
+  using BinaryDecimal::BinaryDecimal;
+
+  static Result<std::unique_ptr<KernelState>> Make(KernelContext*,
+                                                   const KernelInitArgs& args) 
{
+    return BinaryDecimal::Make<MultiplyDecimal>(args);
+  }
+
+  Status Init(int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+    return BinaryDecimal::Init(p1 + p2 + 1, s1 + s2);
+  }
+
+  template <typename T, typename Arg0, typename Arg1>
+  T Call(KernelContext*, Arg0 left, Arg1 right, Status*) const {
+    return left * right;
+  }
+};
+
+struct DivideDecimal : public BinaryDecimal {
+  using BinaryDecimal::BinaryDecimal;
+
+  int32_t left_scaleup;
+
+  static Result<std::unique_ptr<KernelState>> Make(KernelContext*,
+                                                   const KernelInitArgs& args) 
{
+    return BinaryDecimal::Make<DivideDecimal>(args);
+  }
+
+  Status Init(int32_t p1, int32_t s1, int32_t p2, int32_t s2) {
+    // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html

Review comment:
       Please write out an explanation of the upscaling behavior either as a 
comment or in compute.rst

##########
File path: cpp/src/arrow/compute/kernels/test_util.h
##########
@@ -113,6 +113,9 @@ void CheckScalarBinary(std::string func_name, 
std::shared_ptr<Array> left_input,
                        std::shared_ptr<Array> expected,
                        const FunctionOptions* options = nullptr);
 
+void CheckScalarGeneral(std::string func_name, const std::vector<Datum>& 
inputs,

Review comment:
       Why not CheckScalar()?

##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
##########
@@ -1161,5 +1161,312 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) {
   }
 }
 
+class TestBinaryArithmeticDecimal : public TestBase {
+ protected:
+  struct Arg {
+    std::shared_ptr<DataType> type;
+    std::string value;
+  };
+
+  std::shared_ptr<DataType> GetOutType(const std::string& op,
+                                       const std::shared_ptr<DataType>& 
left_type,
+                                       const std::shared_ptr<DataType>& 
right_type) {
+    auto left_decimal_type = std::static_pointer_cast<DecimalType>(left_type);
+    auto right_decimal_type = 
std::static_pointer_cast<DecimalType>(right_type);
+
+    const int32_t p1 = left_decimal_type->precision(), s1 = 
left_decimal_type->scale();
+    const int32_t p2 = right_decimal_type->precision(), s2 = 
right_decimal_type->scale();
+
+    // https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
+    int32_t precision, scale;
+    if (op == "add" || op == "subtract") {
+      scale = std::max(s1, s2);
+      precision = std::max(p1 - s1, p2 - s2) + 1 + scale;
+    } else if (op == "multiply") {
+      scale = s1 + s2;
+      precision = p1 + p2 + 1;
+    } else if (op == "divide") {
+      scale = std::max(4, s1 + p2 - s2 + 1);
+      precision = p1 - s1 + s2 + scale;
+    } else {
+      ABORT_NOT_OK(Status::Invalid("invalid binary operator: ", op));
+    }
+
+    std::shared_ptr<DataType> type;
+    if (left_type->id() == Type::DECIMAL128) {
+      ASSIGN_OR_ABORT(type, Decimal128Type::Make(precision, scale));
+    } else {
+      ASSIGN_OR_ABORT(type, Decimal256Type::Make(precision, scale));
+    }
+    return type;
+  }
+
+  std::shared_ptr<Scalar> MakeScalar(const std::shared_ptr<DataType>& type,
+                                     const std::string& str) {
+    std::shared_ptr<Scalar> scalar;
+    if (type->id() == Type::DECIMAL128) {
+      Decimal128 value;
+      int32_t dummy;
+      ABORT_NOT_OK(Decimal128::FromString(str, &value, &dummy));
+      ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value));
+    } else {
+      Decimal256 value;
+      int32_t dummy;
+      ABORT_NOT_OK(Decimal256::FromString(str, &value, &dummy));
+      ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value));
+    }
+    return scalar;
+  }
+
+  Datum ToDatum(const std::shared_ptr<DataType>& type, const std::string& 
value) {
+    if (value.find("[") == std::string::npos) {
+      return Datum(MakeScalar(type, value));
+    } else {
+      return Datum(ArrayFromJSON(type, value));
+    }
+  }
+
+  void Assert(const std::string& op, const Arg& left, const Arg& right,
+              const std::string& expected) {
+    const Datum arg0 = ToDatum(left.type, left.value);
+    const Datum arg1 = ToDatum(right.type, right.value);
+
+    auto out_type = GetOutType(op, left.type, right.type);
+    CheckScalarGeneral(op, {arg0, arg1}, ToDatum(out_type, expected), 
&options_);
+
+    // commutative operations
+    if (op == "add" || op == "multiply") {
+      CheckScalarGeneral(op, {arg1, arg0}, ToDatum(out_type, expected), 
&options_);
+    }
+  }
+
+  void AssertFail(const std::string& op, const Arg& left, const Arg& right) {
+    const Datum arg0 = ToDatum(left.type, left.value);
+    const Datum arg1 = ToDatum(right.type, right.value);
+
+    ASSERT_NOT_OK(CallFunction(op, {arg0, arg1}, &options_));
+    if (op == "add" || op == "multiply") {
+      ASSERT_NOT_OK(CallFunction(op, {arg1, arg0}, &options_));
+    }
+  }
+
+  ArithmeticOptions options_ = ArithmeticOptions();
+};
+
+// reference result from bc (precision=100, scale=40)
+TEST_F(TestBinaryArithmeticDecimal, AddSubtract) {
+  Arg left, right;
+  std::string added, subtracted;
+
+  // array array, decimal128
+  // clang-format off
+  left = {
+    decimal128(30, 3),
+    R"([
+      "1.000",
+      "-123456789012345678901234567.890",
+      "98765432109876543210.987",
+      "-999999999999999999999999999.999"
+    ])",
+  };
+  right = {
+    decimal128(20, 9),
+    R"([
+      "-1.000000000",
+      "12345678901.234567890",
+      "98765.432101234",
+      "-99999999999.999999999"
+    ])",
+  };
+  added = R"([
+    "0.000000000",
+    "-123456789012345666555555666.655432110",
+    "98765432109876641976.419101234",
+    "-1000000000000000099999999999.998999999"
+  ])";
+  subtracted = R"([
+    "2.000000000",
+    "-123456789012345691246913469.124567890",
+    "98765432109876444445.554898766",
+    "-999999999999999899999999999.999000001"
+  ])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+
+  // array array, decimal256
+  left = {
+    decimal256(30, 20),
+    R"([
+      "-1.00000000000000000001",
+      "1234567890.12345678900000000000",
+      "-9876543210.09876543210987654321",
+      "9999999999.99999999999999999999"
+    ])",
+  };
+  right = {
+    decimal256(30, 10),
+    R"([
+      "1.0000000000",
+      "-1234567890.1234567890",
+      "6789.5432101234",
+      "99999999999999999999.9999999999"
+    ])",
+  };
+  added = R"([
+    "-0.00000000000000000001",
+    "0.00000000000000000000",
+    "-9876536420.55555530870987654321",
+    "100000000009999999999.99999999989999999999"
+  ])";
+  subtracted = R"([
+    "-2.00000000000000000001",
+    "2469135780.24691357800000000000",
+    "-9876549999.64197555550987654321",
+    "-99999999989999999999.99999999990000000001"
+  ])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+  // clang-format on
+
+  // scalar array
+  left = {decimal128(6, 1), "12345.6"};
+  right = {decimal128(10, 3), R"(["1.234", "1234.000", "-9876.543", 
"666.888"])"};
+  added = R"(["12346.834", "13579.600", "2469.057", "13012.488"])";
+  subtracted = R"(["12344.366", "11111.600", "22222.143", "11678.712"])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+  // right - left
+  subtracted = R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])";
+  this->Assert("subtract", right, left, subtracted);
+
+  // scalar scalar
+  left = {decimal256(3, 0), "666"};
+  right = {decimal256(3, 0), "888"};
+  this->Assert("add", left, right, "1554");
+  this->Assert("subtract", left, right, "-222");
+
+  // failed case: result *maybe* overflow
+  left = {decimal128(21, 20), "0.12345678901234567890"};
+  right = {decimal128(21, 1), "1.0"};
+  this->AssertFail("add", left, right);
+  this->AssertFail("subtract", left, right);
+
+  left = {decimal256(75, 0), "0"};
+  right = {decimal256(2, 1), "0.0"};
+  this->AssertFail("add", left, right);
+  this->AssertFail("subtract", left, right);
+}
+
+TEST_F(TestBinaryArithmeticDecimal, Multiply) {
+  Arg left, right;
+  std::string expected;
+
+  // array array
+  // clang-format off
+  left = {
+    decimal128(20, 10),
+    R"([
+      "1234567890.1234567890",
+      "-0.0000000001",
+      "-9999999999.9999999999"
+    ])",
+  };
+  right = {
+    decimal128(13, 3),
+    R"([
+      "1234567890.123",
+      "0.001",
+      "-9999999999.999"
+    ])",
+  };
+  expected = R"([
+    "1524157875323319737.9870903950470",
+    "-0.0000000000001",
+    "99999999999989999999.0000000000001"
+  ])";
+  this->Assert("multiply", left, right, expected);
+
+  left = {
+    decimal256(30, 3),
+    R"([
+      "123456789012345678901234567.890",
+      "0.000"
+    ])",
+  };
+  right = {
+    decimal256(20, 9),
+    R"([
+      "-12345678901.234567890",
+      "99999999999.999999999"
+    ])",
+  };
+  expected = R"([
+    "-1524157875323883675034293577501905199.875019052100",
+    "0.000000000000"
+  ])";
+  this->Assert("multiply", left, right, expected);
+  // clang-format on
+
+  // scalar array
+  left = {decimal128(3, 2), "3.14"};
+  right = {decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"};
+  expected = R"(["3.14", "6.28", "9.42", "12.56", "15.70"])";
+  this->Assert("multiply", left, right, expected);
+
+  // scalar scalar
+  left = {decimal128(1, 0), "1"};
+  right = {decimal128(1, 0), "1"};
+  this->Assert("multiply", left, right, "1");
+
+  // failed case: result *maybe* overflow
+  left = {decimal128(20, 0), "1"};
+  right = {decimal128(18, 1), "1.0"};
+  this->AssertFail("multiply", left, right);
+}
+
+TEST_F(TestBinaryArithmeticDecimal, Divide) {
+  Arg left, right;
+  std::string expected;
+
+  // array array
+  // clang-format off
+  left = {decimal128(13, 3), R"(["1234567890.123", "0.001"])"};
+  right = {decimal128(3, 0), R"(["-987", "999"])"};
+  // scale = 7
+  expected = R"(["-1250828.6627386", "0.0000010"])";
+  this->Assert("divide", left, right, expected);
+
+  left = {decimal256(20, 10), R"(["1234567890.1234567890", 
"9999999999.9999999999"])"};
+  right = {decimal256(13, 3), R"(["1234567890.123", "0.001"])"};
+  // scale = 21
+  expected = R"(["1.000000000000369999093", 
"9999999999999.999999900000000000000"])";
+  this->Assert("divide", left, right, expected);
+  // clang-format on

Review comment:
       Please don't disable clang-format so frequently




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to