cyb70289 commented on a change in pull request #10364:
URL: https://github.com/apache/arrow/pull/10364#discussion_r645468811



##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
##########
@@ -1161,5 +1161,312 @@ TYPED_TEST(TestUnaryArithmeticFloating, AbsoluteValue) {
   }
 }
 
+class TestBinaryArithmeticDecimal : public TestBase {
+ protected:
+  struct Arg {
+    std::shared_ptr<DataType> type;
+    std::string value;
+  };
+
+  std::shared_ptr<DataType> GetOutType(const std::string& op,
+                                       const std::shared_ptr<DataType>& 
left_type,
+                                       const std::shared_ptr<DataType>& 
right_type) {
+    auto left_decimal_type = std::static_pointer_cast<DecimalType>(left_type);
+    auto right_decimal_type = 
std::static_pointer_cast<DecimalType>(right_type);
+
+    const int32_t p1 = left_decimal_type->precision(), s1 = 
left_decimal_type->scale();
+    const int32_t p2 = right_decimal_type->precision(), s2 = 
right_decimal_type->scale();
+
+    // 
https://docs.aws.amazon.com/redshift/latest/dg/r_numeric_computations201.html
+    int32_t precision, scale;
+    if (op == "add" || op == "subtract") {
+      scale = std::max(s1, s2);
+      precision = std::max(p1 - s1, p2 - s2) + 1 + scale;
+    } else if (op == "multiply") {
+      scale = s1 + s2;
+      precision = p1 + p2 + 1;
+    } else if (op == "divide") {
+      scale = std::max(4, s1 + p2 - s2 + 1);
+      precision = p1 - s1 + s2 + scale;
+    } else {
+      ABORT_NOT_OK(Status::Invalid("invalid binary operator: ", op));
+    }
+
+    std::shared_ptr<DataType> type;
+    if (left_type->id() == Type::DECIMAL128) {
+      ASSIGN_OR_ABORT(type, Decimal128Type::Make(precision, scale));
+    } else {
+      ASSIGN_OR_ABORT(type, Decimal256Type::Make(precision, scale));
+    }
+    return type;
+  }
+
+  std::shared_ptr<Scalar> MakeScalar(const std::shared_ptr<DataType>& type,
+                                     const std::string& str) {
+    std::shared_ptr<Scalar> scalar;
+    if (type->id() == Type::DECIMAL128) {
+      Decimal128 value;
+      int32_t dummy;
+      ABORT_NOT_OK(Decimal128::FromString(str, &value, &dummy));
+      ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value));
+    } else {
+      Decimal256 value;
+      int32_t dummy;
+      ABORT_NOT_OK(Decimal256::FromString(str, &value, &dummy));
+      ASSIGN_OR_ABORT(scalar, arrow::MakeScalar(type, value));
+    }
+    return scalar;
+  }
+
+  Datum ToDatum(const std::shared_ptr<DataType>& type, const std::string& 
value) {
+    if (value.find("[") == std::string::npos) {
+      return Datum(MakeScalar(type, value));
+    } else {
+      return Datum(ArrayFromJSON(type, value));
+    }
+  }
+
+  void Assert(const std::string& op, const Arg& left, const Arg& right,
+              const std::string& expected) {
+    const Datum arg0 = ToDatum(left.type, left.value);
+    const Datum arg1 = ToDatum(right.type, right.value);
+
+    auto out_type = GetOutType(op, left.type, right.type);
+    CheckScalarGeneral(op, {arg0, arg1}, ToDatum(out_type, expected), 
&options_);
+
+    // commutative operations
+    if (op == "add" || op == "multiply") {
+      CheckScalarGeneral(op, {arg1, arg0}, ToDatum(out_type, expected), 
&options_);
+    }
+  }
+
+  void AssertFail(const std::string& op, const Arg& left, const Arg& right) {
+    const Datum arg0 = ToDatum(left.type, left.value);
+    const Datum arg1 = ToDatum(right.type, right.value);
+
+    ASSERT_NOT_OK(CallFunction(op, {arg0, arg1}, &options_));
+    if (op == "add" || op == "multiply") {
+      ASSERT_NOT_OK(CallFunction(op, {arg1, arg0}, &options_));
+    }
+  }
+
+  ArithmeticOptions options_ = ArithmeticOptions();
+};
+
+// reference result from bc (precsion=100, scale=40)
+TEST_F(TestBinaryArithmeticDecimal, AddSubtract) {
+  Arg left, right;
+  std::string added, subtracted;
+
+  // array array, decimal128
+  // clang-format off
+  left = {
+    decimal128(30, 3),
+    R"([
+      "1.000",
+      "-123456789012345678901234567.890",
+      "98765432109876543210.987",
+      "-999999999999999999999999999.999"
+    ])",
+  };
+  right = {
+    decimal128(20, 9),
+    R"([
+      "-1.000000000",
+      "12345678901.234567890",
+      "98765.432101234",
+      "-99999999999.999999999"
+    ])",
+  };
+  added = R"([
+    "0.000000000",
+    "-123456789012345666555555666.655432110",
+    "98765432109876641976.419101234",
+    "-1000000000000000099999999999.998999999"
+  ])";
+  subtracted = R"([
+    "2.000000000",
+    "-123456789012345691246913469.124567890",
+    "98765432109876444445.554898766",
+    "-999999999999999899999999999.999000001"
+  ])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+
+  // array array, decimal256
+  left = {
+    decimal256(30, 20),
+    R"([
+      "-1.00000000000000000001",
+      "1234567890.12345678900000000000",
+      "-9876543210.09876543210987654321",
+      "9999999999.99999999999999999999"
+    ])",
+  };
+  right = {
+    decimal256(30, 10),
+    R"([
+      "1.0000000000",
+      "-1234567890.1234567890",
+      "6789.5432101234",
+      "99999999999999999999.9999999999"
+    ])",
+  };
+  added = R"([
+    "-0.00000000000000000001",
+    "0.00000000000000000000",
+    "-9876536420.55555530870987654321",
+    "100000000009999999999.99999999989999999999"
+  ])";
+  subtracted = R"([
+    "-2.00000000000000000001",
+    "2469135780.24691357800000000000",
+    "-9876549999.64197555550987654321",
+    "-99999999989999999999.99999999990000000001"
+  ])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+  // clang-format on
+
+  // scalar array
+  left = {decimal128(6, 1), "12345.6"};
+  right = {decimal128(10, 3), R"(["1.234", "1234.000", "-9876.543", 
"666.888"])"};
+  added = R"(["12346.834", "13579.600", "2469.057", "13012.488"])";
+  subtracted = R"(["12344.366", "11111.600", "22222.143", "11678.712"])";
+  this->Assert("add", left, right, added);
+  this->Assert("subtract", left, right, subtracted);
+  // right - left
+  subtracted = R"(["-12344.366", "-11111.600", "-22222.143", "-11678.712"])";
+  this->Assert("subtract", right, left, subtracted);
+
+  // scalar scalar
+  left = {decimal256(3, 0), "666"};
+  right = {decimal256(3, 0), "888"};
+  this->Assert("add", left, right, "1554");
+  this->Assert("subtract", left, right, "-222");
+
+  // failed case: result *maybe* overflow
+  left = {decimal128(21, 20), "0.12345678901234567890"};
+  right = {decimal128(21, 1), "1.0"};
+  this->AssertFail("add", left, right);
+  this->AssertFail("subtract", left, right);
+
+  left = {decimal256(75, 0), "0"};
+  right = {decimal256(2, 1), "0.0"};
+  this->AssertFail("add", left, right);
+  this->AssertFail("subtract", left, right);
+}
+
+TEST_F(TestBinaryArithmeticDecimal, Multiply) {
+  Arg left, right;
+  std::string expected;
+
+  // array array
+  // clang-format off
+  left = {
+    decimal128(20, 10),
+    R"([
+      "1234567890.1234567890",
+      "-0.0000000001",
+      "-9999999999.9999999999"
+    ])",
+  };
+  right = {
+    decimal128(13, 3),
+    R"([
+      "1234567890.123",
+      "0.001",
+      "-9999999999.999"
+    ])",
+  };
+  expected = R"([
+    "1524157875323319737.9870903950470",
+    "-0.0000000000001",
+    "99999999999989999999.0000000000001"
+  ])";
+  this->Assert("multiply", left, right, expected);
+
+  left = {
+    decimal256(30, 3),
+    R"([
+      "123456789012345678901234567.890",
+      "0.000"
+    ])",
+  };
+  right = {
+    decimal256(20, 9),
+    R"([
+      "-12345678901.234567890",
+      "99999999999.999999999"
+    ])",
+  };
+  expected = R"([
+    "-1524157875323883675034293577501905199.875019052100",
+    "0.000000000000"
+  ])";
+  this->Assert("multiply", left, right, expected);
+  // clang-format on
+
+  // scalar array
+  left = {decimal128(3, 2), "3.14"};
+  right = {decimal128(1, 0), R"(["1", "2", "3", "4", "5"])"};
+  expected = R"(["3.14", "6.28", "9.42", "12.56", "15.70"])";
+  this->Assert("multiply", left, right, expected);
+
+  // scalar scalar
+  left = {decimal128(1, 0), "1"};
+  right = {decimal128(1, 0), "1"};
+  this->Assert("multiply", left, right, "1");
+
+  // failed case: result *maybe* overflow
+  left = {decimal128(20, 0), "1"};
+  right = {decimal128(18, 1), "1.0"};
+  this->AssertFail("multiply", left, right);
+}
+
+TEST_F(TestBinaryArithmeticDecimal, Divide) {
+  Arg left, right;
+  std::string expected;
+
+  // array array
+  // clang-format off
+  left = {decimal128(13, 3), R"(["1234567890.123", "0.001"])"};
+  right = {decimal128(3, 0), R"(["-987", "999"])"};
+  // scale = 7
+  expected = R"(["-1250828.6627386", "0.0000010"])";
+  this->Assert("divide", left, right, expected);
+
+  left = {decimal256(20, 10), R"(["1234567890.1234567890", 
"9999999999.9999999999"])"};
+  right = {decimal256(13, 3), R"(["1234567890.123", "0.001"])"};
+  // scale = 21
+  expected = R"(["1.000000000000369999093", 
"9999999999999.999999900000000000000"])";
+  this->Assert("divide", left, right, expected);
+  // clang-format on

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to