This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new c9d4685c44 GH-37028: [C++] Add support for duration types to if_else functions (#37064)
c9d4685c44 is described below
commit c9d4685c440d1fc0f2369654bc5076ea29dbb55c
Author: Jin Shang <[email protected]>
AuthorDate: Mon Aug 14 23:59:01 2023 +0800
GH-37028: [C++] Add support for duration types to if_else functions (#37064)
### Rationale for this change
The conditional ("if-else") compute functions — if_else, coalesce, choose, and
case_when — do not support duration types.
### What changes are included in this PR?
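Add support for duration types to these functions:
- Register duration kernels for if_else, case_when, coalesce, and choose.
- Extend CommonTemporal so that duration arguments with different units are promoted to a common duration type using the finest unit. For example, duration[s] and duration[ms] both dispatch as duration[ms].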
### Are these changes tested?
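Yes. Tests are extended as follows:
- DurationType is added to the typed primitive if_else test suite.
- The kernel test utilities gain a default duration[s] type instance.
- DispatchBest tests cover mixed-unit duration arguments for if_else, case_when, and coalesce.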
### Are there any user-facing changes?
No.
* Closes: #37028
Authored-by: Jin Shang <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/compute/kernels/codegen_internal.cc | 10 +++++++++-
cpp/src/arrow/compute/kernels/scalar_if_else.cc | 4 ++++
cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 15 +++++++++++++--
cpp/src/arrow/compute/kernels/test_util.h | 5 ++++-
4 files changed, 30 insertions(+), 4 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc
b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index e0156caecf..8e2669bd3d 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -250,7 +250,7 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t
count) {
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;
-
+ bool saw_duration = false;
const TypeHolder* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
@@ -271,6 +271,12 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t
count) {
finest_unit = std::max(finest_unit, ty.unit());
continue;
}
+ case Type::DURATION: {
+ const auto& ty = checked_cast<const DurationType&>(*it->type);
+ finest_unit = std::max(finest_unit, ty.unit());
+ saw_duration = true;
+ continue;
+ }
default:
return TypeHolder(nullptr);
}
@@ -283,6 +289,8 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t
count) {
return date64();
} else if (saw_date32) {
return date32();
+ } else if (saw_duration) {
+ return duration(finest_unit);
}
return TypeHolder(nullptr);
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 0dd176b5d4..6b4b2339e4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -2798,6 +2798,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveIfElseKernels(func, NumericTypes());
AddPrimitiveIfElseKernels(func, TemporalTypes());
AddPrimitiveIfElseKernels(func, IntervalTypes());
+ AddPrimitiveIfElseKernels(func, DurationTypes());
AddPrimitiveIfElseKernels(func, {boolean()});
AddNullIfElseKernel(func);
AddBinaryIfElseKernels(func, BaseBinaryTypes());
@@ -2813,6 +2814,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCaseWhenKernels(func, NumericTypes());
AddPrimitiveCaseWhenKernels(func, TemporalTypes());
AddPrimitiveCaseWhenKernels(func, IntervalTypes());
+ AddPrimitiveCaseWhenKernels(func, DurationTypes());
AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
@@ -2836,6 +2838,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(func, IntervalTypes());
+ AddPrimitiveCoalesceKernels(func, DurationTypes());
AddPrimitiveCoalesceKernels(func, {boolean(), null()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
@@ -2861,6 +2864,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveChooseKernels(func, NumericTypes());
AddPrimitiveChooseKernels(func, TemporalTypes());
AddPrimitiveChooseKernels(func, IntervalTypes());
+ AddPrimitiveChooseKernels(func, DurationTypes());
AddPrimitiveChooseKernels(func, {boolean(), null()});
AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
ChooseFunctor<FixedSizeBinaryType>::Exec);
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index ded73f0371..a9c5a1fc3c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -73,7 +73,7 @@ class TestIfElsePrimitive : public ::testing::Test {};
#ifdef ARROW_VALGRIND
using IfElseNumericBasedTypes =
::testing::Types<UInt32Type, FloatType, Date32Type, Time32Type,
TimestampType,
- MonthIntervalType>;
+ MonthIntervalType, DurationType>;
using BaseBinaryArrowTypes = ::testing::Types<BinaryType>;
using ListArrowTypes = ::testing::Types<ListType>;
using IntegralArrowTypes = ::testing::Types<Int32Type>;
@@ -81,7 +81,8 @@ using IntegralArrowTypes = ::testing::Types<Int32Type>;
using IfElseNumericBasedTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type,
Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, Date32Type,
Date64Type,
- Time32Type, Time64Type, TimestampType, MonthIntervalType>;
+ Time32Type, Time64Type, TimestampType, MonthIntervalType,
+ DurationType>;
#endif
TYPED_TEST_SUITE(TestIfElsePrimitive, IfElseNumericBasedTypes);
@@ -505,6 +506,9 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
{boolean(), timestamp(TimeUnit::MILLI),
timestamp(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
{boolean(), timestamp(TimeUnit::MILLI),
timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest(name,
+ {boolean(), duration(TimeUnit::SECOND),
duration(TimeUnit::MILLI)},
+ {boolean(), duration(TimeUnit::MILLI),
duration(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), date64()},
{boolean(), date64(), date64()});
CheckDispatchBest(name, {boolean(), date32(), date32()},
@@ -2500,6 +2504,11 @@ TEST(TestCaseWhen, DispatchBest) {
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()},
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND)});
+ CheckDispatchBest("case_when",
+ {struct_({field("", boolean())}),
duration(TimeUnit::SECOND),
+ duration(TimeUnit::MILLI)},
+ {struct_({field("", boolean())}),
duration(TimeUnit::MILLI),
+ duration(TimeUnit::MILLI)});
CheckDispatchBest(
"case_when", {struct_({field("", boolean())}), decimal128(38, 0),
decimal128(1, 1)},
{struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)});
@@ -3350,6 +3359,8 @@ TEST(TestCoalesce, DispatchBest) {
{timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND)});
CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::MILLI)},
{timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest("coalesce", {duration(TimeUnit::SECOND),
duration(TimeUnit::MILLI)},
+ {duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
CheckDispatchFails("coalesce", {
sparse_union({field("a", boolean())}),
dense_union({field("a", boolean())}),
diff --git a/cpp/src/arrow/compute/kernels/test_util.h
b/cpp/src/arrow/compute/kernels/test_util.h
index 73762a1ac6..11e77caeff 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -185,7 +185,10 @@ template <typename T>
enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
return std::make_shared<T>(5, 2);
}
-
+template <typename T>
+enable_if_duration<T, std::shared_ptr<DataType>> default_type_instance() {
+ return std::make_shared<T>(TimeUnit::type::SECOND);
+}
// Random Generator Helpers
class RandomImpl {
protected: