pitrou commented on a change in pull request #11080:
URL: https://github.com/apache/arrow/pull/11080#discussion_r712990750



##########
File path: cpp/src/arrow/compute/kernels/codegen_internal.cc
##########
@@ -155,24 +161,44 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* 
begin, size_t count) {
 
 std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& 
descrs) {
   TimeUnit::type finest_unit = TimeUnit::SECOND;
+  const std::string* timezone = nullptr;
+  bool saw_date32 = false;
+  bool saw_date64 = false;
 
   for (const auto& descr : descrs) {
     auto id = descr.type->id();
     // a common timestamp is only possible if all types are timestamp like
     switch (id) {
       case Type::DATE32:
+        // Date32's unit is days, but the coarsest we have is seconds
+        saw_date32 = true;
+        continue;
       case Type::DATE64:
+        finest_unit = std::max(finest_unit, TimeUnit::MILLI);
+        saw_date64 = true;
         continue;
-      case Type::TIMESTAMP:
-        finest_unit =
-            std::max(finest_unit, checked_cast<const 
TimestampType&>(*descr.type).unit());
+      case Type::TIMESTAMP: {
+        const auto& ty = checked_cast<const TimestampType&>(*descr.type);
+        // Don't cast to common timezone by default (may not make
+        // sense for all kernels)
+        if (timezone && *timezone != ty.timezone()) return nullptr;
+        timezone = &ty.timezone();
+        finest_unit = std::max(finest_unit, ty.unit());
         continue;
+      }
       default:
         return nullptr;
     }
   }
 
-  return timestamp(finest_unit);
+  if (timezone) {
+    // At least one timestamp seen
+    return timestamp(finest_unit, *timezone);
+  } else if (saw_date32 && saw_date64) {

Review comment:
       Hmm... if `saw_date64` is true but `saw_date32` is false (i.e. all inputs are date64), we should still return date64, right?

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal_test.cc
##########
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+TEST(TestDispatchBest, CastBinaryDecimalArgs) {
+  std::vector<ValueDescr> args;
+  std::vector<DecimalPromotion> modes = {
+      DecimalPromotion::kAdd, DecimalPromotion::kMultiply, 
DecimalPromotion::kDivide};
+
+  // Any float -> all float
+  for (auto mode : modes) {
+    args = {decimal128(3, 2), float64()};
+    ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
+    AssertTypeEqual(args[0].type, float64());
+    AssertTypeEqual(args[1].type, float64());
+  }
+
+  // Integer -> decimal with common scale
+  args = {decimal128(1, 0), int64()};
+  ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+  AssertTypeEqual(args[0].type, decimal128(1, 0));
+  AssertTypeEqual(args[1].type, decimal128(19, 0));
+}
+
+TEST(TestDispatchBest, CastDecimalArgs) {

Review comment:
       Should negative scales be tested? They are allowed in Arrow.

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal_test.cc
##########
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+TEST(TestDispatchBest, CastBinaryDecimalArgs) {
+  std::vector<ValueDescr> args;
+  std::vector<DecimalPromotion> modes = {
+      DecimalPromotion::kAdd, DecimalPromotion::kMultiply, 
DecimalPromotion::kDivide};
+
+  // Any float -> all float
+  for (auto mode : modes) {
+    args = {decimal128(3, 2), float64()};
+    ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
+    AssertTypeEqual(args[0].type, float64());
+    AssertTypeEqual(args[1].type, float64());
+  }
+
+  // Integer -> decimal with common scale
+  args = {decimal128(1, 0), int64()};
+  ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+  AssertTypeEqual(args[0].type, decimal128(1, 0));
+  AssertTypeEqual(args[1].type, decimal128(19, 0));
+}
+
+TEST(TestDispatchBest, CastDecimalArgs) {
+  std::vector<ValueDescr> args;
+
+  // Any float -> all float
+  args = {decimal128(3, 2), float64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, float64());
+  AssertTypeEqual(args[1].type, float64());
+
+  // Promote to common decimal width
+  args = {decimal128(3, 2), decimal256(3, 2)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal256(3, 2));
+  AssertTypeEqual(args[1].type, decimal256(3, 2));
+
+  // Rescale so all have common scale/precision
+  args = {decimal128(3, 2), decimal128(3, 0)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(5, 2));
+  AssertTypeEqual(args[1].type, decimal128(5, 2));
+
+  // Integer -> decimal with appropriate precision
+  args = {decimal128(3, 0), int64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(19, 0));
+  AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+  args = {decimal128(3, 1), int64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(20, 1));
+  AssertTypeEqual(args[1].type, decimal128(20, 1));
+
+  // Overflow decimal128 max precision -> promote to decimal256
+  args = {decimal128(38, 0), decimal128(37, 2)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal256(40, 2));
+  AssertTypeEqual(args[1].type, decimal256(40, 2));
+}
+
+TEST(TestDispatchBest, CommonTimestamp) {
+  AssertTypeEqual(
+      timestamp(TimeUnit::NANO),
+      CommonTimestamp({timestamp(TimeUnit::SECOND), 
timestamp(TimeUnit::NANO)}));
+  AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
+                  CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"),
+                                   timestamp(TimeUnit::NANO, "UTC")}));
+  AssertTypeEqual(timestamp(TimeUnit::NANO),
+                  CommonTimestamp({date32(), timestamp(TimeUnit::NANO)}));
+  AssertTypeEqual(timestamp(TimeUnit::MILLI),
+                  CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)}));
+  AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()}));

Review comment:
       Note that returning `date64()` here makes the name `CommonTimestamp` a bit weird ;-)

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -1675,6 +1858,234 @@ TEST(TestCoalesce, FixedSizeBinary) {
               ArrayFromJSON(type, R"(["mno", "def", "ghi", "jkl"])"));
   CheckScalar("coalesce", {scalar1, values1},
               ArrayFromJSON(type, R"(["abc", "abc", "abc", "abc"])"));
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      TypeError,
+      ::testing::HasSubstr("coalesce: all types must be identical, expected: "

Review comment:
       Should we replace "identical" with "compatible" in this error message? After all, some implicit casting is allowed.

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal.h
##########
@@ -1278,6 +1278,9 @@ void ReplaceNullWithOtherType(std::vector<ValueDescr>* 
descrs);
 ARROW_EXPORT
 void ReplaceTypes(const std::shared_ptr<DataType>&, std::vector<ValueDescr>* 
descrs);
 
+ARROW_EXPORT
+void ReplaceTypes(const std::shared_ptr<DataType>&, ValueDescr* descrs, size_t 
count);

Review comment:
       Probably a reminder that we'd like a `std::span` backport at some point ;-)

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal.cc
##########
@@ -285,6 +311,59 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
   return Status::OK();
 }
 
+Status CastDecimalArgs(ValueDescr* begin, size_t count) {
+  Type::type casted_type_id = Type::DECIMAL128;
+  auto* end = begin + count;
+
+  int32_t max_scale = 0;
+  for (auto* it = begin; it != end; ++it) {
+    const auto& ty = *it->type;
+    if (is_floating(ty.id())) {
+      // Decimal + float = float
+      ReplaceTypes(float64(), begin, count);

Review comment:
       We should still examine the remaining types in case one of them is incompatible (e.g. non-numeric) before replacing everything with float64, no?

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal.cc
##########
@@ -155,24 +161,44 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* 
begin, size_t count) {
 
 std::shared_ptr<DataType> CommonTimestamp(const std::vector<ValueDescr>& 
descrs) {
   TimeUnit::type finest_unit = TimeUnit::SECOND;
+  const std::string* timezone = nullptr;
+  bool saw_date32 = false;
+  bool saw_date64 = false;
 
   for (const auto& descr : descrs) {
     auto id = descr.type->id();
     // a common timestamp is only possible if all types are timestamp like
     switch (id) {
       case Type::DATE32:
+        // Date32's unit is days, but the coarsest we have is seconds
+        saw_date32 = true;
+        continue;
       case Type::DATE64:
+        finest_unit = std::max(finest_unit, TimeUnit::MILLI);
+        saw_date64 = true;
         continue;
-      case Type::TIMESTAMP:
-        finest_unit =
-            std::max(finest_unit, checked_cast<const 
TimestampType&>(*descr.type).unit());
+      case Type::TIMESTAMP: {
+        const auto& ty = checked_cast<const TimestampType&>(*descr.type);
+        // Don't cast to common timezone by default (may not make
+        // sense for all kernels)
+        if (timezone && *timezone != ty.timezone()) return nullptr;
+        timezone = &ty.timezone();
+        finest_unit = std::max(finest_unit, ty.unit());
         continue;
+      }
       default:
         return nullptr;
     }
   }
 
-  return timestamp(finest_unit);
+  if (timezone) {
+    // At least one timestamp seen
+    return timestamp(finest_unit, *timezone);
+  } else if (saw_date32 && saw_date64) {
+    // Saw mixed date types
+    return date64();
+  }
+  return nullptr;

Review comment:
       Why not return date32 here, when all inputs are date32?

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal_test.cc
##########
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+#include "arrow/type_fwd.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+TEST(TestDispatchBest, CastBinaryDecimalArgs) {
+  std::vector<ValueDescr> args;
+  std::vector<DecimalPromotion> modes = {
+      DecimalPromotion::kAdd, DecimalPromotion::kMultiply, 
DecimalPromotion::kDivide};
+
+  // Any float -> all float
+  for (auto mode : modes) {
+    args = {decimal128(3, 2), float64()};
+    ASSERT_OK(CastBinaryDecimalArgs(mode, &args));
+    AssertTypeEqual(args[0].type, float64());
+    AssertTypeEqual(args[1].type, float64());
+  }
+
+  // Integer -> decimal with common scale
+  args = {decimal128(1, 0), int64()};
+  ASSERT_OK(CastBinaryDecimalArgs(DecimalPromotion::kAdd, &args));
+  AssertTypeEqual(args[0].type, decimal128(1, 0));
+  AssertTypeEqual(args[1].type, decimal128(19, 0));
+}
+
+TEST(TestDispatchBest, CastDecimalArgs) {
+  std::vector<ValueDescr> args;
+
+  // Any float -> all float
+  args = {decimal128(3, 2), float64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, float64());
+  AssertTypeEqual(args[1].type, float64());
+
+  // Promote to common decimal width
+  args = {decimal128(3, 2), decimal256(3, 2)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal256(3, 2));
+  AssertTypeEqual(args[1].type, decimal256(3, 2));
+
+  // Rescale so all have common scale/precision
+  args = {decimal128(3, 2), decimal128(3, 0)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(5, 2));
+  AssertTypeEqual(args[1].type, decimal128(5, 2));
+
+  // Integer -> decimal with appropriate precision
+  args = {decimal128(3, 0), int64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(19, 0));
+  AssertTypeEqual(args[1].type, decimal128(19, 0));
+
+  args = {decimal128(3, 1), int64()};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal128(20, 1));
+  AssertTypeEqual(args[1].type, decimal128(20, 1));
+
+  // Overflow decimal128 max precision -> promote to decimal256
+  args = {decimal128(38, 0), decimal128(37, 2)};
+  ASSERT_OK(CastDecimalArgs(args.data(), args.size()));
+  AssertTypeEqual(args[0].type, decimal256(40, 2));
+  AssertTypeEqual(args[1].type, decimal256(40, 2));
+}
+
+TEST(TestDispatchBest, CommonTimestamp) {
+  AssertTypeEqual(
+      timestamp(TimeUnit::NANO),
+      CommonTimestamp({timestamp(TimeUnit::SECOND), 
timestamp(TimeUnit::NANO)}));
+  AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
+                  CommonTimestamp({timestamp(TimeUnit::SECOND, "UTC"),
+                                   timestamp(TimeUnit::NANO, "UTC")}));
+  AssertTypeEqual(timestamp(TimeUnit::NANO),
+                  CommonTimestamp({date32(), timestamp(TimeUnit::NANO)}));
+  AssertTypeEqual(timestamp(TimeUnit::MILLI),
+                  CommonTimestamp({date64(), timestamp(TimeUnit::SECOND)}));
+  AssertTypeEqual(date64(), CommonTimestamp({date32(), date64()}));
+  ASSERT_EQ(nullptr, CommonTimestamp({date32(), date32()}));

Review comment:
       But why? I'd expect the common type of two `date32` inputs to be `date32()`, not `nullptr`.

##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -2063,51 +2146,178 @@ struct CoalesceFunctor<Type, 
enable_if_base_binary<Type>> {
   }
 
   static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* 
out) {
-    // Special case: grab any leading non-null scalar or array arguments
+    return ExecVarWidthCoalesceImpl(
+        ctx, batch, out,
+        [&](ArrayBuilder* builder) {
+          int64_t reservation = 0;
+          for (const auto& datum : batch.values) {
+            if (datum.is_array()) {
+              const ArrayType array(datum.array());
+              reservation = std::max<int64_t>(reservation, 
array.total_values_length());
+            } else {
+              const auto& scalar = *datum.scalar();
+              if (scalar.is_valid) {
+                const int64_t size = UnboxScalar<Type>::Unbox(scalar).size();
+                reservation = std::max<int64_t>(reservation, batch.length * 
size);
+              }
+            }
+          }
+          return checked_cast<BuilderType*>(builder)->ReserveData(reservation);
+        },
+        [&](ArrayBuilder* builder, const Scalar& scalar) {
+          return checked_cast<BuilderType*>(builder)->Append(
+              UnboxScalar<Type>::Unbox(scalar));
+        });
+  }
+};
+
+template <>
+struct CoalesceFunctor<FixedSizeListType> {

Review comment:
       It seems you could simply have:
   ```c++
   template <typename Type>
   struct CoalesceFunctor<Type,
       typename std::enable_if<is_nested_type<Type>::value &&
                               !is_union_type<Type>::value>::type> {
     // common implementation for non-union nested types
   };
   ```
   

##########
File path: cpp/src/arrow/compute/kernels/codegen_internal.cc
##########
@@ -285,6 +311,59 @@ Status CastBinaryDecimalArgs(DecimalPromotion promotion,
   return Status::OK();
 }
 
+Status CastDecimalArgs(ValueDescr* begin, size_t count) {
+  Type::type casted_type_id = Type::DECIMAL128;
+  auto* end = begin + count;
+
+  int32_t max_scale = 0;
+  for (auto* it = begin; it != end; ++it) {
+    const auto& ty = *it->type;
+    if (is_floating(ty.id())) {
+      // Decimal + float = float
+      ReplaceTypes(float64(), begin, count);
+      return Status::OK();
+    } else if (is_integer(ty.id())) {
+      // Nothing to do here
+    } else if (is_decimal(ty.id())) {
+      max_scale = std::max(max_scale, checked_cast<const 
DecimalType&>(ty).scale());
+      if (ty.id() == Type::DECIMAL256) {
+        casted_type_id = Type::DECIMAL256;
+      }
+    } else {
+      // Non-numeric, can't cast
+      return Status::OK();
+    }
+  }
+
+  // All integer and decimal, rescale
+  int32_t common_precision = 0;
+  for (auto* it = begin; it != end; ++it) {
+    const auto& ty = *it->type;
+    if (is_integer(ty.id())) {
+      ARROW_ASSIGN_OR_RAISE(auto precision, 
MaxDecimalDigitsForInteger(ty.id()));
+      precision += max_scale;
+      common_precision = std::max(common_precision, precision);
+    } else if (is_decimal(ty.id())) {
+      const auto& decimal_ty = checked_cast<const DecimalType&>(ty);
+      auto precision = decimal_ty.precision();
+      const auto scale = decimal_ty.scale();
+      precision += max_scale - scale;
+      common_precision = std::max(common_precision, precision);
+    }
+  }
+
+  if (common_precision > BasicDecimal128::kMaxPrecision) {

Review comment:
       Should there be an error in case `BasicDecimal256::kMaxPrecision` is exceeded?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to