lidavidm commented on pull request #12338:
URL: https://github.com/apache/arrow/pull/12338#issuecomment-1047834425


   Yes, all the kernels need to be added in the same place. However, we don't 
need to add new kernels; we can instead branch within the existing kernel. See 
this sketch (incomplete):
   
   <details>
   
   ```diff
   diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc 
b/cpp/src/arrow/compute/kernels/scalar_compare.cc
   index fc8527810..7ee887233 100644
   --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
   +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
   @@ -70,24 +70,24 @@ struct GreaterEqual {
    
    struct Minimum {
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_floating_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::fmin(left, right);
      }
    
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_integer_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::min(left, right);
      }
    
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_decimal_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::min(left, right);
      }
    
   -  static string_view Call(string_view left, string_view right) {
   +  static string_view Call(KernelContext*, string_view left, string_view 
right, Status*) {
        return std::min(left, right);
      }
    
   @@ -114,24 +114,24 @@ struct Minimum {
    
    struct Maximum {
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_floating_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::fmax(left, right);
      }
    
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_integer_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_integer_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::max(left, right);
      }
    
      template <typename T, typename Arg0, typename Arg1>
   -  static enable_if_decimal_value<T> Call(Arg0 left, Arg1 right) {
   +  static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 
right, Status*) {
        static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, 
Arg1>::value, "");
        return std::max(left, right);
      }
    
   -  static string_view Call(string_view left, string_view right) {
   +  static string_view Call(KernelContext*, string_view left, string_view 
right, Status*) {
        return std::max(left, right);
      }
    
   @@ -330,11 +330,12 @@ template <typename OutType, typename Op>
    struct ScalarMinMax {
      using OutValue = typename GetOutputType<OutType>::T;
    
   -  static void ExecScalar(const ExecBatch& batch,
   +  static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch,
                             const ElementWiseAggregateOptions& options, 
Scalar* out) {
        // All arguments are scalar
        OutValue value{};
        bool valid = false;
   +    Status st = Status::OK();
        for (const auto& arg : batch.values) {
          // Ignore non-scalar arguments so we can use it in the 
mixed-scalar-and-array case
          if (!arg.is_scalar()) continue;
   @@ -342,31 +343,63 @@ struct ScalarMinMax {
          if (!scalar.is_valid) {
            if (options.skip_nulls) continue;
            out->is_valid = false;
   -        return;
   +        return st;
          }
          if (!valid) {
            value = UnboxScalar<OutType>::Unbox(scalar);
            valid = true;
          } else {
            value = Op::template Call<OutValue, OutValue, OutValue>(
   -            value, UnboxScalar<OutType>::Unbox(scalar));
   +            ctx, value, UnboxScalar<OutType>::Unbox(scalar), &st);
   +        RETURN_NOT_OK(st);
          }
        }
        out->is_valid = valid;
        if (valid) {
          BoxScalar<OutType>::Box(value, out);
        }
   +    return st;
   +  }
   +
   +  // Specialization for 2-ary case with skip_nulls
   +  static Status ExecBinary(KernelContext* ctx, const ExecBatch& batch, 
Datum* out) {
   +    if (out->is_scalar()) {
   +      out->scalar()->is_valid = batch[0].scalar()->is_valid && 
batch[1].scalar()->is_valid;
   +    } else if (batch[0].is_scalar()) {
   +      ArrayData* output = out->mutable_array();
   +      // TODO: if scalar is invalid, allocate all-null bitmap; else:
   +      // if batch[1].MayHaveNulls(), copy its bitmap (batch[1] should
   +      // not be a scalar since then we'd be in the first case above)
   +    } else if (batch[1].is_scalar()) {
   +      // TODO:
   +    } else {
   +      ArrayData* output = out->mutable_array();
   +      const ArrayData& left = *batch[0].array();
   +      const ArrayData& right = *batch[1].array();
   +      // TODO: note that left and/or right may not have a validity
   +      // bitmap
   +      ARROW_ASSIGN_OR_RAISE(output->buffers[0], 
::arrow::internal::BitmapAnd(
   +          ctx->memory_pool(),
   +          left.GetValues<uint8_t>(0, 0), left.offset,
   +          right.GetValues<uint8_t>(0, 0), right.offset,
   +          batch.length, /*out_offset=*/0));
   +    }
   +    return applicator::ScalarBinaryEqualTypes<OutType, OutType, 
Op>::Exec(ctx, batch, out);
      }
    
      static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* 
out) {
        const ElementWiseAggregateOptions& options = MinMaxState::Get(ctx);
   +
   +    if (batch.num_values() == 2 && options.skip_nulls) {
   +      return ExecBinary(ctx, batch, out);
   +    }
   +
        const auto descrs = batch.GetDescriptors();
        const size_t scalar_count =
            static_cast<size_t>(std::count_if(batch.values.begin(), 
batch.values.end(),
                                              [](const Datum& d) { return 
d.is_scalar(); }));
        if (scalar_count == batch.values.size()) {
   -      ExecScalar(batch, options, out->scalar().get());
   -      return Status::OK();
   +      return ExecScalar(ctx, batch, options, out->scalar().get());
        }
    
        ArrayData* output = out->mutable_array();
   @@ -382,7 +415,7 @@ struct ScalarMinMax {
        if (scalar_count > 0) {
          ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> temp_scalar,
                                MakeScalar(out->type(), 0));
   -      ExecScalar(batch, options, temp_scalar.get());
   +      RETURN_NOT_OK(ExecScalar(ctx, batch, options, temp_scalar.get()));
          if (temp_scalar->is_valid) {
            const auto value = UnboxScalar<OutType>::Unbox(*temp_scalar);
            initialize_output = false;
   @@ -444,6 +477,7 @@ struct ScalarMinMax {
          }
        }
    
   +    Status st = Status::OK();
        for (const auto& array : arrays) {
          OutputArrayWriter<OutType> writer(out->mutable_array());
          ArrayIterator<OutType> out_it(*output);
   @@ -454,7 +488,7 @@ struct ScalarMinMax {
                auto u = out_it();
                if (!output->buffers[0] ||
                    bit_util::GetBit(output->buffers[0]->data(), index)) {
   -              writer.Write(Op::template Call<OutValue, OutValue, 
OutValue>(u, value));
   +              writer.Write(Op::template Call<OutValue, OutValue, 
OutValue>(ctx, u, value, &st));
                } else {
                  writer.Write(value);
                }
   @@ -468,7 +502,7 @@ struct ScalarMinMax {
              });
        }
        output->null_count = output->buffers[0] ? -1 : 0;
   -    return Status::OK();
   +    return st;
      }
    };
    
   @@ -492,6 +526,7 @@ Status ExecBinaryMinMaxScalar(KernelContext* ctx,
      const auto& first_scalar = *batch.values.front().scalar();
      string_view result = checked_cast<const 
BaseBinaryScalar&>(first_scalar).view();
      bool valid = first_scalar.is_valid;
   +  Status st = Status::OK();
      for (size_t i = 1; i < batch.values.size(); i++) {
        const auto& scalar = *batch[i].scalar();
        if (!scalar.is_valid) {
   @@ -499,7 +534,7 @@ Status ExecBinaryMinMaxScalar(KernelContext* ctx,
          continue;
        } else {
          string_view value = checked_cast<const 
BaseBinaryScalar&>(scalar).view();
   -      result = !valid ? value : Op::Call(result, value);
   +      result = !valid ? value : Op::Call(ctx, result, value, &st);
          valid = true;
        }
      }
   @@ -510,7 +545,7 @@ Status ExecBinaryMinMaxScalar(KernelContext* ctx,
      } else {
        output->is_valid = false;
      }
   -  return Status::OK();
   +  return st;
    }
    
    template <typename Type, typename Op>
   @@ -537,10 +572,11 @@ struct BinaryScalarMinMax {
        RETURN_NOT_OK(builder.Reserve(batch.length));
        RETURN_NOT_OK(builder.ReserveData(estimated_final_size));
    
   +    Status st = Status::OK();
        for (int64_t row = 0; row < batch.length; row++) {
          util::optional<string_view> result;
          auto visit_value = [&](string_view value) {
   -        result = !result ? value : Op::Call(*result, value);
   +        result = !result ? value : Op::Call(ctx, *result, value, &st);
          };
    
          for (size_t col = 0; col < batch.values.size(); col++) {
   @@ -548,6 +584,7 @@ struct BinaryScalarMinMax {
              const auto& scalar = *batch[col].scalar();
              if (scalar.is_valid) {
                visit_value(UnboxScalar<Type>::Unbox(scalar));
   +            RETURN_NOT_OK(st);
              } else if (!options.skip_nulls) {
                result = util::nullopt;
                break;
   @@ -561,6 +598,7 @@ struct BinaryScalarMinMax {
                const int64_t length = offsets[row + 1] - offsets[row];
                visit_value(
                    string_view(reinterpret_cast<const char*>(data + 
offsets[row]), length));
   +            RETURN_NOT_OK(st);
              } else if (!options.skip_nulls) {
                result = util::nullopt;
                break;
   @@ -629,10 +667,11 @@ struct FixedSizeBinaryScalarMinMax {
        RETURN_NOT_OK(builder.ReserveData(estimated_final_size));
    
        std::vector<string_view> valid_cols(batch.values.size());
   +    Status st = Status::OK();
        for (int64_t row = 0; row < batch.length; row++) {
          string_view result;
          auto visit_value = [&](string_view value) {
   -        result = result.empty() ? value : Op::Call(result, value);
   +        result = result.empty() ? value : Op::Call(ctx, result, value, &st);
          };
    
          for (size_t col = 0; col < batch.values.size(); col++) {
   @@ -640,6 +679,7 @@ struct FixedSizeBinaryScalarMinMax {
              const auto& scalar = *batch[col].scalar();
              if (scalar.is_valid) {
                visit_value(UnboxScalar<FixedSizeBinaryType>::Unbox(scalar));
   +            RETURN_NOT_OK(st);
              } else if (!options.skip_nulls) {
                result = string_view();
                break;
   @@ -651,6 +691,7 @@ struct FixedSizeBinaryScalarMinMax {
                const auto data = array.GetValues<uint8_t>(1, 
/*absolute_offset=*/0);
                visit_value(string_view(
                    reinterpret_cast<const char*>(data) + row * byte_width, 
byte_width));
   +            RETURN_NOT_OK(st);
              } else if (!options.skip_nulls) {
                result = string_view();
                break;
   @@ -699,12 +740,14 @@ std::shared_ptr<ScalarFunction> 
MakeScalarMinMax(std::string name,
      auto func = std::make_shared<VarArgsCompareFunction>(
          name, Arity::VarArgs(), doc, &default_element_wise_aggregate_options);
      for (const auto& ty : NumericTypes()) {
   -    auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
   -    ScalarKernel kernel{KernelSignature::Make({ty}, ty, 
/*is_varargs=*/true), exec,
   -                        MinMaxState::Init};
   -    kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
   -    kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
   -    DCHECK_OK(func->AddKernel(std::move(kernel)));
   +    {
   +      auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
   +      ScalarKernel kernel{KernelSignature::Make({ty}, ty, 
/*is_varargs=*/true), exec,
   +        MinMaxState::Init};
   +      kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
   +      kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
   +      DCHECK_OK(func->AddKernel(std::move(kernel)));
   +    }
      }
      for (const auto& ty : TemporalTypes()) {
        auto exec = GeneratePhysicalNumeric<ScalarMinMax, Op>(ty);
   ```
   
   </details>
   
   By the way, I see on the JIRA that we want to return NaN when only a single 
NaN input is provided. Isn't this already the case?
   
   ```python
   >>> pc.min_element_wise(pa.scalar(float('nan')))
   <pyarrow.DoubleScalar: nan>
   >>> pc.min_element_wise([float('nan')])
   <pyarrow.lib.DoubleArray object at 0x7f1693ad7e80>
   [
     nan
   ]
   >>> pc.min_element_wise([float('nan')], pa.array([None], 'float32'))
   <pyarrow.lib.DoubleArray object at 0x7f1693b40160>
   [
     nan
   ]
   >>> pc.min_element_wise(pa.scalar(float('nan')), pa.scalar(None, 'float32'))
   <pyarrow.DoubleScalar: nan>
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to