This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 5709c0dbb5 GH-39233: [Compute] Add some duration kernels (#39358)
5709c0dbb5 is described below
commit 5709c0dbb52e2075bd89a4fde69030c3eac385dd
Author: Jin Shang <[email protected]>
AuthorDate: Thu Jan 11 23:38:41 2024 +0800
GH-39233: [Compute] Add some duration kernels (#39358)
### Rationale for this change
Add kernels for durations.
### What changes are included in this PR?
In this PR I added the ones that require only registration and unit tests.
More complicated ones will be in another PR for readability.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
No.
* Closes: #39233
Authored-by: Jin Shang <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/compute/kernels/scalar_arithmetic.cc | 35 ++++++++++
cpp/src/arrow/compute/kernels/scalar_compare.cc | 9 +++
.../arrow/compute/kernels/scalar_compare_test.cc | 7 +-
.../arrow/compute/kernels/scalar_temporal_test.cc | 12 ++++
cpp/src/arrow/compute/kernels/scalar_validity.cc | 6 +-
.../arrow/compute/kernels/scalar_validity_test.cc | 7 ++
docs/source/cpp/compute.rst | 78 +++++++++++-----------
7 files changed, 113 insertions(+), 41 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index ad33d7f895..44f5fea790 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -1286,12 +1286,27 @@ void RegisterScalarArithmetic(FunctionRegistry*
registry) {
auto absolute_value =
MakeUnaryArithmeticFunction<AbsoluteValue>("abs", absolute_value_doc);
AddDecimalUnaryKernels<AbsoluteValue>(absolute_value.get());
+
+ // abs(duration)
+ for (auto unit : TimeUnit::values()) {
+ auto exec = ArithmeticExecFromOp<ScalarUnary,
AbsoluteValue>(duration(unit));
+ DCHECK_OK(
+ absolute_value->AddKernel({duration(unit)},
OutputType(duration(unit)), exec));
+ }
+
DCHECK_OK(registry->AddFunction(std::move(absolute_value)));
// ----------------------------------------------------------------------
auto absolute_value_checked =
MakeUnaryArithmeticFunctionNotNull<AbsoluteValueChecked>(
"abs_checked", absolute_value_checked_doc);
AddDecimalUnaryKernels<AbsoluteValueChecked>(absolute_value_checked.get());
+ // abs_checked(duration)
+ for (auto unit : TimeUnit::values()) {
+ auto exec =
+ ArithmeticExecFromOp<ScalarUnaryNotNull,
AbsoluteValueChecked>(duration(unit));
+ DCHECK_OK(absolute_value_checked->AddKernel({duration(unit)},
+ OutputType(duration(unit)),
exec));
+ }
DCHECK_OK(registry->AddFunction(std::move(absolute_value_checked)));
// ----------------------------------------------------------------------
@@ -1545,12 +1560,27 @@ void RegisterScalarArithmetic(FunctionRegistry*
registry) {
// ----------------------------------------------------------------------
auto negate = MakeUnaryArithmeticFunction<Negate>("negate", negate_doc);
AddDecimalUnaryKernels<Negate>(negate.get());
+
+ // Add negate(duration) -> duration
+ for (auto unit : TimeUnit::values()) {
+ auto exec = ArithmeticExecFromOp<ScalarUnary, Negate>(duration(unit));
+ DCHECK_OK(negate->AddKernel({duration(unit)}, OutputType(duration(unit)),
exec));
+ }
+
DCHECK_OK(registry->AddFunction(std::move(negate)));
// ----------------------------------------------------------------------
auto negate_checked =
MakeUnarySignedArithmeticFunctionNotNull<NegateChecked>(
"negate_checked", negate_checked_doc);
AddDecimalUnaryKernels<NegateChecked>(negate_checked.get());
+
+ // Add negate_checked(duration) -> duration
+ for (auto unit : TimeUnit::values()) {
+ auto exec = ArithmeticExecFromOp<ScalarUnaryNotNull,
NegateChecked>(duration(unit));
+ DCHECK_OK(
+ negate_checked->AddKernel({duration(unit)},
OutputType(duration(unit)), exec));
+ }
+
DCHECK_OK(registry->AddFunction(std::move(negate_checked)));
// ----------------------------------------------------------------------
@@ -1581,6 +1611,11 @@ void RegisterScalarArithmetic(FunctionRegistry*
registry) {
// ----------------------------------------------------------------------
auto sign =
MakeUnaryArithmeticFunctionWithFixedIntOutType<Sign, Int8Type>("sign",
sign_doc);
+ // sign(duration)
+ for (auto unit : TimeUnit::values()) {
+ auto exec = ScalarUnary<Int8Type, Int64Type, Sign>::Exec;
+ DCHECK_OK(sign->AddKernel({duration(unit)}, int8(), std::move(exec)));
+ }
DCHECK_OK(registry->AddFunction(std::move(sign)));
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc
b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index aad648ca27..daf8ed76d6 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -22,6 +22,7 @@
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common_internal.h"
+#include "arrow/type.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"
@@ -806,6 +807,14 @@ std::shared_ptr<ScalarFunction>
MakeScalarMinMax(std::string name, FunctionDoc d
kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
DCHECK_OK(func->AddKernel(std::move(kernel)));
}
+ for (const auto& ty : DurationTypes()) {
+ auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
+ ScalarKernel kernel{KernelSignature::Make({ty}, ty, /*is_varargs=*/true),
exec,
+ MinMaxState::Init};
+ kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
for (const auto& ty : BaseBinaryTypes()) {
auto exec =
GenerateTypeAgnosticVarBinaryBase<BinaryScalarMinMax, ArrayKernelExec,
Op>(ty);
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
index 48fa780b03..8f5952b405 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc
@@ -1281,7 +1281,7 @@ using CompareNumericBasedTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, Date32Type,
Date64Type>;
using CompareParametricTemporalTypes =
- ::testing::Types<TimestampType, Time32Type, Time64Type>;
+ ::testing::Types<TimestampType, Time32Type, Time64Type, DurationType>;
using CompareFixedSizeBinaryTypes = ::testing::Types<FixedSizeBinaryType>;
TYPED_TEST_SUITE(TestVarArgsCompareNumeric, CompareNumericBasedTypes);
@@ -2121,6 +2121,11 @@ TEST(TestMaxElementWiseMinElementWise, CommonTemporal) {
ScalarFromJSON(date64(), "172800000"),
}),
ResultWith(ScalarFromJSON(date64(), "86400000")));
+ EXPECT_THAT(MinElementWise({
+ ScalarFromJSON(duration(TimeUnit::SECOND), "1"),
+ ScalarFromJSON(duration(TimeUnit::MILLI), "12000"),
+ }),
+ ResultWith(ScalarFromJSON(duration(TimeUnit::MILLI), "1000")));
}
} // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
index d448233428..8dac6525fe 100644
--- a/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_temporal_test.cc
@@ -3665,5 +3665,17 @@ TEST_F(ScalarTemporalTest,
TestCeilFloorRoundTemporalDate) {
CheckScalarUnary("ceil_temporal", arr_ns, arr_ns, &round_to_2_hours);
}
+TEST_F(ScalarTemporalTest, DurationUnaryArithmetics) {
+ auto arr = ArrayFromJSON(duration(TimeUnit::SECOND), "[2, -1, null, 3, 0]");
+ CheckScalarUnary("negate", arr,
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null,
-3, 0]"));
+ CheckScalarUnary("negate_checked", arr,
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[-2, 1, null,
-3, 0]"));
+ CheckScalarUnary("abs", arr,
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3,
0]"));
+ CheckScalarUnary("abs_checked", arr,
+ ArrayFromJSON(duration(TimeUnit::SECOND), "[2, 1, null, 3,
0]"));
+ CheckScalarUnary("sign", arr, ArrayFromJSON(int8(), "[1, -1, null, 1, 0]"));
+}
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_validity.cc
b/cpp/src/arrow/compute/kernels/scalar_validity.cc
index 6b1cec0f5c..8505fc4c6e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_validity.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_validity.cc
@@ -169,6 +169,7 @@ std::shared_ptr<ScalarFunction>
MakeIsFiniteFunction(std::string name, FunctionD
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(),
ConstBoolExec<true>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(),
ConstBoolExec<true>));
+ DCHECK_OK(func->AddKernel({InputType(Type::DURATION)}, boolean(),
ConstBoolExec<true>));
return func;
}
@@ -187,7 +188,8 @@ std::shared_ptr<ScalarFunction>
MakeIsInfFunction(std::string name, FunctionDoc
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(),
ConstBoolExec<false>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(),
ConstBoolExec<false>));
-
+ DCHECK_OK(
+ func->AddKernel({InputType(Type::DURATION)}, boolean(),
ConstBoolExec<false>));
return func;
}
@@ -205,6 +207,8 @@ std::shared_ptr<ScalarFunction>
MakeIsNanFunction(std::string name, FunctionDoc
func->AddKernel({InputType(Type::DECIMAL128)}, boolean(),
ConstBoolExec<false>));
DCHECK_OK(
func->AddKernel({InputType(Type::DECIMAL256)}, boolean(),
ConstBoolExec<false>));
+ DCHECK_OK(
+ func->AddKernel({InputType(Type::DURATION)}, boolean(),
ConstBoolExec<false>));
return func;
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
index 94d951c838..d1462838f3 100644
--- a/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_validity_test.cc
@@ -103,6 +103,9 @@ TEST(TestValidityKernels, IsFinite) {
}
CheckScalar("is_finite", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
+ CheckScalar("is_finite",
+ {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42, null]")},
+ ArrayFromJSON(boolean(), "[true, true, true, null]"));
}
TEST(TestValidityKernels, IsInf) {
@@ -116,6 +119,8 @@ TEST(TestValidityKernels, IsInf) {
}
CheckScalar("is_inf", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
+ CheckScalar("is_inf", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42,
null]")},
+ ArrayFromJSON(boolean(), "[false, false, false, null]"));
}
TEST(TestValidityKernels, IsNan) {
@@ -129,6 +134,8 @@ TEST(TestValidityKernels, IsNan) {
}
CheckScalar("is_nan", {std::make_shared<NullArray>(4)},
ArrayFromJSON(boolean(), "[null, null, null, null]"));
+ CheckScalar("is_nan", {ArrayFromJSON(duration(TimeUnit::SECOND), "[0, 1, 42,
null]")},
+ ArrayFromJSON(boolean(), "[false, false, false, null]"));
}
TEST(TestValidityKernels, IsValidIsNullNullType) {
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 17d003b261..e7310d2c0c 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -458,45 +458,45 @@ floating-point arguments will cast all arguments to
floating-point, while mixed
decimal and integer arguments will cast all arguments to decimals.
Mixed time resolution temporal inputs will be cast to finest input resolution.
-+------------------+--------+-------------------------+----------------------+-------+
-| Function name | Arity | Input types | Output type |
Notes |
-+==================+========+=========================+======================+=======+
-| abs | Unary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| abs_checked | Unary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| add | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| add_checked | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| divide | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| exp | Unary | Numeric | Float32/Float64 |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| multiply | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| negate | Unary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| negate_checked | Unary | Signed Numeric | Signed Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| power | Binary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| power_checked | Binary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| sign | Unary | Numeric | Int8/Float32/Float64 |
\(2) |
-+------------------+--------+-------------------------+----------------------+-------+
-| sqrt | Unary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| sqrt_checked | Unary | Numeric | Numeric |
|
-+------------------+--------+-------------------------+----------------------+-------+
-| subtract | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
-| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal |
\(1) |
-+------------------+--------+-------------------------+----------------------+-------+
++------------------+--------+-------------------------+---------------------------+-------+
+| Function name | Arity | Input types | Output type
| Notes |
++==================+========+=========================+===========================+=======+
+| abs | Unary | Numeric/Duration | Numeric/Duration
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| abs_checked | Unary | Numeric/Duration | Numeric/Duration
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| add | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| add_checked | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| divide | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| divide_checked | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| exp | Unary | Numeric | Float32/Float64
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| multiply | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| multiply_checked | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| negate | Unary | Numeric/Duration | Numeric/Duration
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| negate_checked | Unary | Signed Numeric/Duration | Signed
Numeric/Duration | |
++------------------+--------+-------------------------+---------------------------+-------+
+| power | Binary | Numeric | Numeric
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| power_checked | Binary | Numeric | Numeric
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| sign | Unary | Numeric/Duration | Int8/Float32/Float64
| \(2) |
++------------------+--------+-------------------------+---------------------------+-------+
+| sqrt | Unary | Numeric | Numeric
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| sqrt_checked | Unary | Numeric | Numeric
| |
++------------------+--------+-------------------------+---------------------------+-------+
+| subtract | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
+| subtract_checked | Binary | Numeric/Temporal | Numeric/Temporal
| \(1) |
++------------------+--------+-------------------------+---------------------------+-------+
* \(1) Precision and scale of computed DECIMAL results