This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new c9d4685c44 GH-37028: [C++] Add support for duration types to if_else 
functions (#37064)
c9d4685c44 is described below

commit c9d4685c440d1fc0f2369654bc5076ea29dbb55c
Author: Jin Shang <[email protected]>
AuthorDate: Mon Aug 14 23:59:01 2023 +0800

    GH-37028: [C++] Add support for duration types to if_else functions (#37064)
    
    
    
    ### Rationale for this change
    
    Support for duration types is missing from the if-else family of functions:
    if_else, coalesce, choose, and case_when.
    
    ### What changes are included in this PR?
    Add duration kernels to these functions, and extend CommonTemporal so that
    durations with different units are promoted to the finest unit.
    
    ### Are these changes tested?
    Yes.
    
    ### Are there any user-facing changes?
    
    No.
    
    * Closes: #37028
    
    Authored-by: Jin Shang <[email protected]>
    Signed-off-by: Antoine Pitrou <[email protected]>
---
 cpp/src/arrow/compute/kernels/codegen_internal.cc    | 10 +++++++++-
 cpp/src/arrow/compute/kernels/scalar_if_else.cc      |  4 ++++
 cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 15 +++++++++++++--
 cpp/src/arrow/compute/kernels/test_util.h            |  5 ++++-
 4 files changed, 30 insertions(+), 4 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc 
b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index e0156caecf..8e2669bd3d 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -250,7 +250,7 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t 
count) {
   const std::string* timezone = nullptr;
   bool saw_date32 = false;
   bool saw_date64 = false;
-
+  bool saw_duration = false;
   const TypeHolder* end = begin + count;
   for (auto it = begin; it != end; it++) {
     auto id = it->type->id();
@@ -271,6 +271,12 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t 
count) {
         finest_unit = std::max(finest_unit, ty.unit());
         continue;
       }
+      case Type::DURATION: {
+        const auto& ty = checked_cast<const DurationType&>(*it->type);
+        finest_unit = std::max(finest_unit, ty.unit());
+        saw_duration = true;
+        continue;
+      }
       default:
         return TypeHolder(nullptr);
     }
@@ -283,6 +289,8 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t 
count) {
     return date64();
   } else if (saw_date32) {
     return date32();
+  } else if (saw_duration) {
+    return duration(finest_unit);
   }
   return TypeHolder(nullptr);
 }
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc 
b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 0dd176b5d4..6b4b2339e4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -2798,6 +2798,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveIfElseKernels(func, NumericTypes());
     AddPrimitiveIfElseKernels(func, TemporalTypes());
     AddPrimitiveIfElseKernels(func, IntervalTypes());
+    AddPrimitiveIfElseKernels(func, DurationTypes());
     AddPrimitiveIfElseKernels(func, {boolean()});
     AddNullIfElseKernel(func);
     AddBinaryIfElseKernels(func, BaseBinaryTypes());
@@ -2813,6 +2814,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveCaseWhenKernels(func, NumericTypes());
     AddPrimitiveCaseWhenKernels(func, TemporalTypes());
     AddPrimitiveCaseWhenKernels(func, IntervalTypes());
+    AddPrimitiveCaseWhenKernels(func, DurationTypes());
     AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
     AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
                       CaseWhenFunctor<FixedSizeBinaryType>::Exec);
@@ -2836,6 +2838,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveCoalesceKernels(func, NumericTypes());
     AddPrimitiveCoalesceKernels(func, TemporalTypes());
     AddPrimitiveCoalesceKernels(func, IntervalTypes());
+    AddPrimitiveCoalesceKernels(func, DurationTypes());
     AddPrimitiveCoalesceKernels(func, {boolean(), null()});
     AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
                       CoalesceFunctor<FixedSizeBinaryType>::Exec);
@@ -2861,6 +2864,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddPrimitiveChooseKernels(func, NumericTypes());
     AddPrimitiveChooseKernels(func, TemporalTypes());
     AddPrimitiveChooseKernels(func, IntervalTypes());
+    AddPrimitiveChooseKernels(func, DurationTypes());
     AddPrimitiveChooseKernels(func, {boolean(), null()});
     AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
                     ChooseFunctor<FixedSizeBinaryType>::Exec);
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc 
b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index ded73f0371..a9c5a1fc3c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -73,7 +73,7 @@ class TestIfElsePrimitive : public ::testing::Test {};
 #ifdef ARROW_VALGRIND
 using IfElseNumericBasedTypes =
     ::testing::Types<UInt32Type, FloatType, Date32Type, Time32Type, 
TimestampType,
-                     MonthIntervalType>;
+                     MonthIntervalType, DurationType>;
 using BaseBinaryArrowTypes = ::testing::Types<BinaryType>;
 using ListArrowTypes = ::testing::Types<ListType>;
 using IntegralArrowTypes = ::testing::Types<Int32Type>;
@@ -81,7 +81,8 @@ using IntegralArrowTypes = ::testing::Types<Int32Type>;
 using IfElseNumericBasedTypes =
     ::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, 
Int16Type,
                      Int32Type, Int64Type, FloatType, DoubleType, Date32Type, 
Date64Type,
-                     Time32Type, Time64Type, TimestampType, MonthIntervalType>;
+                     Time32Type, Time64Type, TimestampType, MonthIntervalType,
+                     DurationType>;
 #endif
 
 TYPED_TEST_SUITE(TestIfElsePrimitive, IfElseNumericBasedTypes);
@@ -505,6 +506,9 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
                     {boolean(), timestamp(TimeUnit::MILLI), 
timestamp(TimeUnit::MILLI)});
   CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
                     {boolean(), timestamp(TimeUnit::MILLI), 
timestamp(TimeUnit::MILLI)});
+  CheckDispatchBest(name,
+                    {boolean(), duration(TimeUnit::SECOND), 
duration(TimeUnit::MILLI)},
+                    {boolean(), duration(TimeUnit::MILLI), 
duration(TimeUnit::MILLI)});
   CheckDispatchBest(name, {boolean(), date32(), date64()},
                     {boolean(), date64(), date64()});
   CheckDispatchBest(name, {boolean(), date32(), date32()},
@@ -2500,6 +2504,11 @@ TEST(TestCaseWhen, DispatchBest) {
       {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()},
       {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND),
        timestamp(TimeUnit::SECOND)});
+  CheckDispatchBest("case_when",
+                    {struct_({field("", boolean())}), 
duration(TimeUnit::SECOND),
+                     duration(TimeUnit::MILLI)},
+                    {struct_({field("", boolean())}), 
duration(TimeUnit::MILLI),
+                     duration(TimeUnit::MILLI)});
   CheckDispatchBest(
       "case_when", {struct_({field("", boolean())}), decimal128(38, 0), 
decimal128(1, 1)},
       {struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)});
@@ -3350,6 +3359,8 @@ TEST(TestCoalesce, DispatchBest) {
                     {timestamp(TimeUnit::SECOND), 
timestamp(TimeUnit::SECOND)});
   CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), 
timestamp(TimeUnit::MILLI)},
                     {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+  CheckDispatchBest("coalesce", {duration(TimeUnit::SECOND), 
duration(TimeUnit::MILLI)},
+                    {duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
   CheckDispatchFails("coalesce", {
                                      sparse_union({field("a", boolean())}),
                                      dense_union({field("a", boolean())}),
diff --git a/cpp/src/arrow/compute/kernels/test_util.h 
b/cpp/src/arrow/compute/kernels/test_util.h
index 73762a1ac6..11e77caeff 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -185,7 +185,10 @@ template <typename T>
 enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
   return std::make_shared<T>(5, 2);
 }
-
+template <typename T>
+enable_if_duration<T, std::shared_ptr<DataType>> default_type_instance() {
+  return std::make_shared<T>(TimeUnit::type::SECOND);
+}
 // Random Generator Helpers
 class RandomImpl {
  protected:

Reply via email to