jorisvandenbossche commented on a change in pull request #10557:
URL: https://github.com/apache/arrow/pull/10557#discussion_r670578004



##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -676,7 +677,353 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
   }
 }
 
-}  // namespace
+// Helper to copy or broadcast fixed-width values between buffers.
+template <typename Type, typename Enable = void>
+struct CopyFixedWidth {};
+template <>
+struct CopyFixedWidth<BooleanType> {
+  static void CopyScalar(const Scalar& scalar, const int64_t length,
+                         uint8_t* raw_out_values, const int64_t out_offset) {
+    const bool value = UnboxScalar<BooleanType>::Unbox(scalar);
+    BitUtil::SetBitsTo(raw_out_values, out_offset, length, value);
+  }
+  static void CopyArray(const DataType&, const uint8_t* in_values,
+                        const int64_t in_offset, const int64_t length,
+                        uint8_t* raw_out_values, const int64_t out_offset) {
+    arrow::internal::CopyBitmap(in_values, in_offset, length, raw_out_values, out_offset);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_number<Type>> {
+  using CType = typename TypeTraits<Type>::CType;
+  static void CopyScalar(const Scalar& scalar, const int64_t length,
+                         uint8_t* raw_out_values, const int64_t out_offset) {
+    CType* out_values = reinterpret_cast<CType*>(raw_out_values);
+    const CType value = UnboxScalar<Type>::Unbox(scalar);
+    std::fill(out_values + out_offset, out_values + out_offset + length, value);
+  }
+  static void CopyArray(const DataType&, const uint8_t* in_values,
+                        const int64_t in_offset, const int64_t length,
+                        uint8_t* raw_out_values, const int64_t out_offset) {
+    std::memcpy(raw_out_values + out_offset * sizeof(CType),
+                in_values + in_offset * sizeof(CType), length * sizeof(CType));
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyScalar(const Scalar& values, const int64_t length,
+                         uint8_t* raw_out_values, const int64_t out_offset) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = raw_out_values + (width * out_offset);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(values);
+    // Scalar may have null value buffer
+    if (!scalar.value) return;
+    DCHECK_EQ(scalar.value->size(), width);
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, scalar.value->data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const DataType& type, const uint8_t* in_values,
+                        const int64_t in_offset, const int64_t length,
+                        uint8_t* raw_out_values, const int64_t out_offset) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(type).byte_width();
+    uint8_t* next = raw_out_values + (width * out_offset);
+    std::memcpy(next, in_values + in_offset * width, length * width);
+  }
+};
+template <typename Type>
+struct CopyFixedWidth<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyScalar(const Scalar& values, const int64_t length,
+                         uint8_t* raw_out_values, const int64_t out_offset) {
+    const int32_t width =
+        checked_cast<const FixedSizeBinaryType&>(*values.type).byte_width();
+    uint8_t* next = raw_out_values + (width * out_offset);
+    const auto& scalar = checked_cast<const ScalarType&>(values);
+    const auto value = scalar.value.ToBytes();
+    for (int i = 0; i < length; i++) {
+      std::memcpy(next, value.data(), width);
+      next += width;
+    }
+  }
+  static void CopyArray(const DataType& type, const uint8_t* in_values,
+                        const int64_t in_offset, const int64_t length,
+                        uint8_t* raw_out_values, const int64_t out_offset) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(type).byte_width();
+    uint8_t* next = raw_out_values + (width * out_offset);
+    std::memcpy(next, in_values + in_offset * width, length * width);
+  }
+};
+// Copy fixed-width values from a scalar/array datum into an output values buffer
+template <typename Type>
+void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length,
+                uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) {
+  if (in_values.is_scalar()) {
+    const auto& scalar = *in_values.scalar();
+    if (out_valid) {
+      BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid);
+    }
+    CopyFixedWidth<Type>::CopyScalar(scalar, length, out_values, out_offset);
+  } else {
+    const ArrayData& array = *in_values.array();
+    if (out_valid) {
+      if (array.MayHaveNulls()) {
+        if (length == 1) {
+          // CopyBitmap is slow for short runs
+          BitUtil::SetBitTo(
+              out_valid, out_offset,
+              BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset));
+        } else {
+          arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset,
+                                      length, out_valid, out_offset);
+        }
+      } else {
+        BitUtil::SetBitsTo(out_valid, out_offset, length, true);
+      }
+    }
+    CopyFixedWidth<Type>::CopyArray(*array.type, array.buffers[1]->data(),
+                                    array.offset + in_offset, length, out_values,
+                                    out_offset);
+  }
+}
+
+struct CaseWhenFunction : ScalarFunction {
+  using ScalarFunction::ScalarFunction;
+
+  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
+    // The first argument is a struct of booleans, where the number of fields in the
+    // struct is either equal to the number of other arguments or is one less.
+    RETURN_NOT_OK(CheckArity(*values));
+    EnsureDictionaryDecoded(values);
+    auto first_type = (*values)[0].type;
+    if (first_type->id() != Type::STRUCT) {
+      return Status::TypeError("case_when: first argument must be STRUCT, not ",
+                               *first_type);
+    }
+    auto num_fields = static_cast<size_t>(first_type->num_fields());
+    if (num_fields < values->size() - 2 || num_fields >= values->size()) {
+      return Status::Invalid(
+          "case_when: number of struct fields must be equal to or one less 
than count of "
+          "remaining arguments (",
+          values->size() - 1, "), got: ", first_type->num_fields());
+    }
+    for (const auto& field : first_type->fields()) {
+      if (field->type()->id() != Type::BOOL) {
+        return Status::TypeError(
+            "case_when: all fields of first argument must be BOOL, but ", 
field->name(),
+            " was of type: ", *field->type());
+      }
+    }
+
+    if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
+      for (auto it = values->begin() + 1; it != values->end(); it++) {
+        it->type = type;
+      }
+    }
+    if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+    return arrow::compute::detail::NoMatchingKernel(this, *values);
+  }
+};
+
+// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions
+template <typename Type>
+Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& conds = checked_cast<const StructScalar&>(*batch.values[0].scalar());
+  if (!conds.is_valid) {
+    return Status::Invalid("cond struct must not be null");
+  }
+  Datum result;
+  for (size_t i = 0; i < batch.values.size() - 1; i++) {
+    if (i < conds.value.size()) {
+      const Scalar& cond = *conds.value[i];
+      if (cond.is_valid && internal::UnboxScalar<BooleanType>::Unbox(cond)) {
+        result = batch[i + 1];
+        break;
+      }
+    } else {
+      // ELSE clause
+      result = batch[i + 1];
+      break;
+    }
+  }
+  if (out->is_scalar()) {
+    *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
+    return Status::OK();
+  }
+  ArrayData* output = out->mutable_array();
+  if (!result.is_value()) {
+    // All conditions false, no 'else' argument
+    result = MakeNullScalar(out->type());
+  }
+  CopyValues<Type>(result, /*in_offset=*/0, batch.length,
+                   output->GetMutableValues<uint8_t>(0, 0),
+                   output->GetMutableValues<uint8_t>(1, 0), output->offset);
+  return Status::OK();
+}
+
+// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type,
+// given helper functions to copy data from a source array to a target array
+template <typename Type>
+Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+  const auto& conds_array = *batch.values[0].array();
+  if (conds_array.GetNullCount() > 0) {
+    return Status::Invalid("cond struct must not have nulls");

Review comment:
   ```suggestion
       return Status::Invalid("cond struct must not have top-level nulls");
   ```
   
   ? (not fully sure if this is the correct terminology, but just to clarify which nulls this refers to, as each individual condition field can still have nulls)
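   
   For reference, a quick sketch of the two cases (not part of this PR; just illustrating with the public C++ `Scalar` API, type and field names made up):
   
   ```cpp
   #include <iostream>
   #include <memory>
   #include <vector>
   
   #include <arrow/api.h>
   
   int main() {
     auto cond_type = arrow::struct_({arrow::field("cond1", arrow::boolean()),
                                      arrow::field("cond2", arrow::boolean())});
   
     // Top-level null: the struct scalar itself is invalid. This is the case
     // the error message is about.
     auto top_level_null = arrow::MakeNullScalar(cond_type);
     std::cout << "top-level null valid? " << top_level_null->is_valid << "\n";  // 0
   
     // Field-level null: the struct is valid, but cond1 is a null boolean.
     // This is still accepted; a null condition simply never matches.
     arrow::StructScalar::ValueType fields = {
         arrow::MakeNullScalar(arrow::boolean()),
         std::make_shared<arrow::BooleanScalar>(true)};
     auto field_null = std::make_shared<arrow::StructScalar>(fields, cond_type);
     std::cout << "field-level-null valid? " << field_null->is_valid << "\n";  // 1
     return 0;
   }
   ```
   
   Only the first case should hit this check; in the second, the scalar kernel just treats the null condition as not matched (per the `cond.is_valid && ...` test in `ExecScalarCaseWhen`).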



