pitrou commented on a change in pull request #11218:
URL: https://github.com/apache/arrow/pull/11218#discussion_r721367975



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -900,44 +896,182 @@ struct IfElseFunctor<Type, enable_if_fixed_size_binary<Type>> {
     auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
 
     // copy right data to out_buff
-    const util::string_view& right_data =
-        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(right);
-    if (right_data.data()) {
+    const uint8_t* right_data = UnboxBinaryScalar(right);
+    if (right_data) {
       for (int64_t i = 0; i < cond.length; i++) {
-        std::memcpy(out_values + i * byte_width, right_data.data(), right_data.size());
+        std::memcpy(out_values + i * byte_width, right_data, byte_width);
       }
     }
 
     // selectively copy values from left data
-    const util::string_view& left_data =
-        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(left);
-
+    const uint8_t* left_data = UnboxBinaryScalar(left);
     RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
-      if (left_data.data()) {
+      if (left_data) {
         for (int64_t i = 0; i < num_elems; i++) {
-          std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(),
-                      left_data.size());
+          std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
         }
       }
     });
 
     return Status::OK();
   }
 
-  static Result<int32_t> GetByteWidth(const DataType& left_type,
-                                      const DataType& right_type) {
-    int width = checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
-    if (width == checked_cast<const FixedSizeBinaryType&>(right_type).byte_width()) {
-      return width;
+  template <typename T = Type>
+  static enable_if_t<!is_decimal_type<T>::value, const uint8_t*> UnboxBinaryScalar(
+      const Scalar& scalar) {
+    return reinterpret_cast<const uint8_t*>(
+        internal::UnboxScalar<FixedSizeBinaryType>::Unbox(scalar).data());
+  }
+
+  template <typename T = Type>
+  static enable_if_decimal<T, const uint8_t*> UnboxBinaryScalar(const Scalar& scalar) {
+    return internal::UnboxScalar<T>::Unbox(scalar).native_endian_bytes();
+  }
+
+  template <typename T = Type>
+  static enable_if_t<!is_decimal_type<T>::value, Result<int32_t>> GetByteWidth(
+      const DataType& left_type, const DataType& right_type) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
+    DCHECK_EQ(width, checked_cast<const FixedSizeBinaryType&>(right_type).byte_width());
+    return width;
+  }
+
+  template <typename T = Type>
+  static enable_if_decimal<T, Result<int32_t>> GetByteWidth(const DataType& left_type,
+                                                            const DataType& right_type) {
+    const auto& left = checked_cast<const T&>(left_type);
+    const auto& right = checked_cast<const T&>(right_type);
+    DCHECK_EQ(left.precision(), right.precision());
+    DCHECK_EQ(left.scale(), right.scale());
+    return left.byte_width();
+  }
+};
+
+// Use builders for dictionaries - slower, but allows us to unify dictionaries
+template <typename Type>
+struct IfElseFunctor<
+    Type, enable_if_t<is_nested_type<Type>::value || is_dictionary_type<Type>::value>> {
+  // A - Array, S - Scalar, X = Array/Scalar
+
+  // SXX
+  static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+                     const Datum& right, Datum* out) {
+    if (left.is_scalar() && right.is_scalar()) {
+      if (cond.is_valid) {
+        *out = cond.value ? left.scalar() : right.scalar();
+      } else {
+        *out = MakeNullScalar(left.type());
+      }
+      return Status::OK();
+    }
+    // either left or right is an array. Output is always an array
+    int64_t out_arr_len = std::max(left.length(), right.length());
+    if (!cond.is_valid) {
+      // cond is null; just create a null array
+      ARROW_ASSIGN_OR_RAISE(*out,
+                            MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool()))
+      return Status::OK();
+    }
+
+    const auto& valid_data = cond.value ? left : right;
+    if (valid_data.is_array()) {
+      *out = valid_data;
     } else {
-      return Status::Invalid("FixedSizeBinaryType byte_widths should be equal");
+      // valid data is a scalar that needs to be broadcasted
+      ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len,
+                                                      ctx->memory_pool()));
     }
+    return Status::OK();
+  }
+
+  //  AAA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    return RunLoop(

Review comment:
       Great, thank you!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to