lidavidm commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r661730002



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,351 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(out_values, offset, length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length,
+                                out_values, offset);
+  }
+};
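+// Numeric types: the values buffer holds contiguous C values, so a scalar
+// broadcasts with std::fill and an array slice copies with std::copy.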
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& values, uint8_t* raw_out_values,
+                         const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(values);
+    std::fill(out_values + offset, out_values + offset + length, value);
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* raw_out_values,
+                        const int64_t offset, const int64_t length) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType* in_values = array.GetValues<CType>(1);
+    std::copy(in_values + offset, in_values + offset + length, out_values + offset);
+  }
+};
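+// Fixed-size binary: values are opaque byte_width()-sized slots; a scalar is
+// broadcast by repeating its value buffer, an array slice is one memcpy.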
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+    // Scalar may have null value buffer
+    if (!scalar.value) return;
+    DCHECK_EQ(scalar.value->size(), width);
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, scalar.value->data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
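+// Decimals: same fixed-width slot layout as fixed-size binary, except the
+// scalar's bytes are materialized via the decimal value's ToBytes().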
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset,
+                         const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto& scalar = checked_cast<const ScalarType&>(values);
+    const auto value = scalar.value.ToBytes();
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, value.data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset,
+                        const int64_t length) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*array.type).byte_width();
+    uint8_t* next = out_values + (width * offset);
+    const auto* in_values = array.GetValues<uint8_t>(1, (offset + array.offset) * width);
+    std::memcpy(next, in_values, length * width);
+  }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values,
+                const int64_t offset, const int64_t length) {
+  using Copier = CopyFixedWidth<Type>;
+  if (values.is_scalar()) {
+    const auto& scalar = *values.scalar();
+    if (out_valid) {
+      BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid);
+    }
+    Copier::CopyScalar(scalar, out_values, offset, length);
+  } else {
+    const ArrayData& array = *values.array();
+    if (out_valid) {
+      if (array.MayHaveNulls()) {
+        arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset,
+                                    length, out_valid, offset);
+      } else {
+        BitUtil::SetBitsTo(out_valid, offset, length, true);
+      }
+    }
+    Copier::CopyArray(array, out_values, offset, length);
+  }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+  using ScalarFunction::ScalarFunction;
+
+  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+    RETURN_NOT_OK(CheckArity(*values));
+    std::vector<ValueDescr> value_types;
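+    // Arguments alternate (condition, value) pairs, optionally followed by
+    // a trailing ELSE value.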
+    for (size_t i = 0; i < values->size() - 1; i += 2) {
+      ValueDescr* cond = &(*values)[i];
+      if (cond->type->id() == Type::NA) {
+        cond->type = boolean();
+      }
+      if (cond->type->id() != Type::BOOL) {
+        return Status::Invalid("Condition arguments must be boolean, but argument ", i,
+                               " was ", cond->type->ToString());
+      }
+      value_types.push_back((*values)[i + 1]);
+    }
+    if (values->size() % 2 != 0) {
+      // Have an ELSE clause
+      value_types.push_back(values->back());
+    }
+    EnsureDictionaryDecoded(&value_types);
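+    // Promote the value arguments to a common numeric type when one exists.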
+    if (auto type = CommonNumeric(value_types)) {
+      ReplaceTypes(type, &value_types);
+    }
+
+    const DataType& common_values_type = *value_types.front().type;
+    auto next_type = value_types.cbegin();
+    for (size_t i = 0; i < values->size(); i += 2) {
+      if (!common_values_type.Equals(next_type->type)) {
+        return Status::Invalid("Value arguments must be of same type, but argument ", i,
+                               " was ", next_type->type->ToString(), " (expected ",
+                               common_values_type.ToString(), ")");
+      }
+      if (i == values->size() - 1) {
+        // ELSE
+        (*values)[i] = *next_type++;
+      } else {
+        (*values)[i + 1] = *next_type++;
+      }
+    }
+
+    // We register a unary kernel for each value type and dispatch to it after validation.
+    if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel;
+    return arrow::compute::detail::NoMatchingKernel(this, *values);
+  }
+};
+
+// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments
+Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
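+  // Emit the value paired with the first condition that is both valid and true.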
+  for (size_t i = 0; i < batch.values.size() - 1; i += 2) {
+    const Scalar& cond = *batch[i].scalar();
+    if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+      *out = batch[i + 1];
+      return Status::OK();
+    }
+  }
+  if (batch.values.size() % 2 == 0) {
+    // No ELSE
+    *out = MakeNullScalar(batch[1].type());
+  } else {
+    *out = batch.values.back();
+  }
+  return Status::OK();
+}
+
+// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type,
+// given helper functions to copy data from a source array to a target array and to
+// allocate a values buffer
+template <typename Type>
+Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  ArrayData* output = out->mutable_array();
+  const bool have_else_arg = batch.values.size() % 2 != 0;
+  // Check if we may need a validity bitmap
+  uint8_t* out_valid = nullptr;
+
+  bool need_valid_bitmap = false;
+  if (!have_else_arg) {
+    // If we don't have an else arg -> need a bitmap since we may emit nulls
+    need_valid_bitmap = true;
+  } else if (batch.values.back().null_count() > 0) {
+    // If the 'else' array has a null count we need a validity bitmap
+    need_valid_bitmap = true;
+  } else {
+    // Otherwise if any value array has a null count we need a validity bitmap
+    for (size_t i = 1; i < batch.values.size(); i += 2) {
+      if (batch[i].null_count() > 0) {
+        need_valid_bitmap = true;
+        break;
+      }
+    }
+  }
+  if (need_valid_bitmap) {
+    ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length));
+    out_valid = output->buffers[0]->mutable_data();
+  }
+
+  // Initialize values buffer
+  uint8_t* out_values = output->buffers[1]->mutable_data();
+  if (have_else_arg) {
+    // Copy 'else' value into output

Review comment:
       This seems to be faster as written, oddly.
   
   Before:
   
   ```
   -----------------------------------------------------------------------------------------------
   Benchmark                                     Time             CPU   Iterations UserCounters...
   -----------------------------------------------------------------------------------------------
   CaseWhenBench32/1048576/0              31933112 ns     31932368 ns           22 bytes_per_second=504.974M/s
   CaseWhenBench64/1048576/0              33170481 ns     33168735 ns           21 bytes_per_second=968.533M/s
   CaseWhenBench32/1048576/99             32487300 ns     32487411 ns           21 bytes_per_second=496.299M/s
   CaseWhenBench64/1048576/99             33682029 ns     33680901 ns           21 bytes_per_second=953.715M/s
   CaseWhenBench32Contiguous/1048576/0     7255445 ns      7255387 ns           96 bytes_per_second=1.632G/s
   CaseWhenBench64Contiguous/1048576/0     7932437 ns      7932171 ns           88 bytes_per_second=2.97013G/s
   CaseWhenBench32Contiguous/1048576/99    7526742 ns      7526709 ns           92 bytes_per_second=1.57303G/s
   CaseWhenBench64Contiguous/1048576/99    8172498 ns      8172239 ns           83 bytes_per_second=2.88261G/s
   ```
   
   After:
   
   ```
   -----------------------------------------------------------------------------------------------
   Benchmark                                     Time             CPU   Iterations UserCounters...
   -----------------------------------------------------------------------------------------------
   CaseWhenBench32/1048576/0              44166172 ns     44165634 ns           16 bytes_per_second=365.103M/s
   CaseWhenBench64/1048576/0              44605356 ns     44603995 ns           16 bytes_per_second=720.227M/s
   CaseWhenBench32/1048576/99             44867670 ns     44867051 ns           16 bytes_per_second=359.361M/s
   CaseWhenBench64/1048576/99             45077818 ns     45076721 ns           15 bytes_per_second=712.607M/s
   CaseWhenBench32Contiguous/1048576/0    17757494 ns     17757271 ns           39 bytes_per_second=682.819M/s
   CaseWhenBench64Contiguous/1048576/0    18236327 ns     18235892 ns           38 bytes_per_second=1.29193G/s
   CaseWhenBench32Contiguous/1048576/99   15051281 ns     15051008 ns           39 bytes_per_second=805.518M/s
   CaseWhenBench64Contiguous/1048576/99   15504081 ns     15503998 ns           45 bytes_per_second=1.51944G/s
   ```
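   
   For context, here is a minimal standalone sketch of the two broad strategies a kernel like this can take: blanket the output with the ELSE values and then overwrite the slots whose condition fired, versus selecting per row. This is my own illustration with hypothetical names, not the PR's code, and it assumes the slower variant corresponds to a per-row scan:
   
   ```
   #include <algorithm>
   #include <cstddef>
   #include <cstdint>
   #include <vector>
   
   // Strategy A: fill the output with the ELSE values, then overwrite with each
   // value array where its condition holds. Applying conditions in reverse order
   // preserves case_when semantics: the first true condition's value is written
   // last and wins. All fills/copies stay sequential and vectorize well.
   void CaseWhenBlanket(const std::vector<const bool*>& conds,
                        const std::vector<const int32_t*>& vals,
                        const int32_t* else_vals, int32_t* out, size_t n) {
     std::copy(else_vals, else_vals + n, out);
     for (size_t c = conds.size(); c-- > 0;) {
       for (size_t i = 0; i < n; ++i) {
         if (conds[c][i]) out[i] = vals[c][i];
       }
     }
   }
   
   // Strategy B: one pass that scans the conditions for each row until one
   // fires. Each output slot is written once, but the inner loop branches
   // unpredictably per row.
   void CaseWhenPerRow(const std::vector<const bool*>& conds,
                       const std::vector<const int32_t*>& vals,
                       const int32_t* else_vals, int32_t* out, size_t n) {
     for (size_t i = 0; i < n; ++i) {
       out[i] = else_vals[i];
       for (size_t c = 0; c < conds.size(); ++c) {
         if (conds[c][i]) {
           out[i] = vals[c][i];
           break;
         }
       }
     }
   }
   ```
   
   I can't tell from this excerpt alone that the 2x gap reduces to exactly this difference, but the numbers are consistent with the blanket approach keeping memory traffic sequential, especially in the Contiguous cases.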
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

