pitrou commented on a change in pull request #11218:
URL: https://github.com/apache/arrow/pull/11218#discussion_r719433410
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -989,23 +1123,41 @@ struct IfElseFunction : ScalarFunction {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;
- if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ // Do not DispatchExact here because it'll let through something like
(bool,
+ // timestamp[s], timestamp[s, "UTC"])
// if 0th descriptor is null, replace with bool
if (values->at(0).type->id() == Type::NA) {
values->at(0).type = boolean();
}
// if-else 0'th descriptor is bool, so skip it
- std::vector<ValueDescr> values_copy(values->begin() + 1, values->end());
- internal::EnsureDictionaryDecoded(&values_copy);
- internal::ReplaceNullWithOtherType(&values_copy);
+ ValueDescr* left_arg = &(*values)[1];
+ constexpr size_t num_args = 2;
- if (auto type = internal::CommonNumeric(values_copy)) {
- internal::ReplaceTypes(type, &values_copy);
+ internal::ReplaceNullWithOtherType(left_arg, num_args);
+
+ if (is_dictionary((*values)[1].type->id()) &&
+ (*values)[1].type->Equals(*(*values)[2].type)) {
Review comment:
Hmm, what about the reverse case, where the second type is a dictionary of
the first type?
Either way, can you add a comment explaining how that case is handled?
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -900,44 +896,182 @@ struct IfElseFunctor<Type,
enable_if_fixed_size_binary<Type>> {
auto* out_values = out->buffers[1]->mutable_data() + out->offset *
byte_width;
// copy right data to out_buff
- const util::string_view& right_data =
- internal::UnboxScalar<FixedSizeBinaryType>::Unbox(right);
- if (right_data.data()) {
+ const uint8_t* right_data = UnboxBinaryScalar(right);
+ if (right_data) {
for (int64_t i = 0; i < cond.length; i++) {
- std::memcpy(out_values + i * byte_width, right_data.data(),
right_data.size());
+ std::memcpy(out_values + i * byte_width, right_data, byte_width);
}
}
// selectively copy values from left data
- const util::string_view& left_data =
- internal::UnboxScalar<FixedSizeBinaryType>::Unbox(left);
-
+ const uint8_t* left_data = UnboxBinaryScalar(left);
RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
- if (left_data.data()) {
+ if (left_data) {
for (int64_t i = 0; i < num_elems; i++) {
- std::memcpy(out_values + (data_offset + i) * byte_width,
left_data.data(),
- left_data.size());
+ std::memcpy(out_values + (data_offset + i) * byte_width, left_data,
byte_width);
}
}
});
return Status::OK();
}
- static Result<int32_t> GetByteWidth(const DataType& left_type,
- const DataType& right_type) {
- int width = checked_cast<const
FixedSizeBinaryType&>(left_type).byte_width();
- if (width == checked_cast<const
FixedSizeBinaryType&>(right_type).byte_width()) {
- return width;
+ template <typename T = Type>
+ static enable_if_t<!is_decimal_type<T>::value, const uint8_t*>
UnboxBinaryScalar(
+ const Scalar& scalar) {
+ return reinterpret_cast<const uint8_t*>(
+ internal::UnboxScalar<FixedSizeBinaryType>::Unbox(scalar).data());
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal<T, const uint8_t*> UnboxBinaryScalar(const Scalar&
scalar) {
+ return internal::UnboxScalar<T>::Unbox(scalar).native_endian_bytes();
+ }
+
+ template <typename T = Type>
+ static enable_if_t<!is_decimal_type<T>::value, Result<int32_t>> GetByteWidth(
+ const DataType& left_type, const DataType& right_type) {
+ const int32_t width =
+ checked_cast<const FixedSizeBinaryType&>(left_type).byte_width();
+ DCHECK_EQ(width, checked_cast<const
FixedSizeBinaryType&>(right_type).byte_width());
+ return width;
+ }
+
+ template <typename T = Type>
+ static enable_if_decimal<T, Result<int32_t>> GetByteWidth(const DataType&
left_type,
+ const DataType&
right_type) {
+ const auto& left = checked_cast<const T&>(left_type);
+ const auto& right = checked_cast<const T&>(right_type);
+ DCHECK_EQ(left.precision(), right.precision());
+ DCHECK_EQ(left.scale(), right.scale());
+ return left.byte_width();
+ }
+};
+
+// Use builders for dictionaries - slower, but allows us to unify dictionaries
+template <typename Type>
+struct IfElseFunctor<
+ Type, enable_if_t<is_nested_type<Type>::value ||
is_dictionary_type<Type>::value>> {
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const
Datum& left,
+ const Datum& right, Datum* out) {
+ if (left.is_scalar() && right.is_scalar()) {
+ if (cond.is_valid) {
+ *out = cond.value ? left.scalar() : right.scalar();
+ } else {
+ *out = MakeNullScalar(left.type());
+ }
+ return Status::OK();
+ }
+ // either left or right is an array. Output is always an array
+ int64_t out_arr_len = std::max(left.length(), right.length());
+ if (!cond.is_valid) {
+ // cond is null; just create a null array
+ ARROW_ASSIGN_OR_RAISE(*out,
+ MakeArrayOfNull(left.type(), out_arr_len,
ctx->memory_pool()))
+ return Status::OK();
+ }
+
+ const auto& valid_data = cond.value ? left : right;
+ if (valid_data.is_array()) {
+ *out = valid_data;
} else {
- return Status::Invalid("FixedSizeBinaryType byte_widths should be
equal");
+ // valid data is a scalar that needs to be broadcasted
+ ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(),
out_arr_len,
+ ctx->memory_pool()));
}
+ return Status::OK();
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const
ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ return RunLoop(
Review comment:
It looks like this could be much more efficient if `RunLoop` batched
runs of consecutive 0s and 1s in the condition.
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -1031,7 +1183,16 @@ void AddPrimitiveIfElseKernels(const
std::shared_ptr<ScalarFunction>& scalar_fun
internal::GenerateTypeAgnosticPrimitive<ResolveIfElseExec,
/*AllocateMem=*/std::false_type>(*type);
// cond array needs to be boolean always
- ScalarKernel kernel({boolean(), type, type}, type, exec);
+ std::shared_ptr<KernelSignature> sig;
+ if (type->id() == Type::TIMESTAMP) {
+ auto unit = checked_cast<const TimestampType&>(*type).unit();
+ sig = KernelSignature::Make(
+ {boolean(), match::TimestampTypeUnit(unit),
match::TimestampTypeUnit(unit)},
Review comment:
Wouldn't `CheckIdenticalTypes` later catch the differing units?
(Also, shouldn't the timezones be required to match as well?)
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -140,12 +157,27 @@ void CheckWithDifferentShapes(const
std::shared_ptr<Array>& cond,
"@" + std::to_string(right_idx) + "=" +
right_in.scalar()->ToString();
} else {
right_in = right_bcast = right;
+ trace_right += "=Array";
}
SCOPED_TRACE(trace_right);
- ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast,
right_bcast));
+ Datum expected;
ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in,
right_in));
- AssertDatumsEqual(exp, actual, /*verbose=*/true);
+ if (mask & COND_SCALAR && mask & LEFT_SCALAR && mask & RIGHT_SCALAR)
{
Review comment:
Could you add parentheses to make this more readable, or write `mask == (COND_SCALAR |
LEFT_SCALAR | RIGHT_SCALAR)`?
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else.cc
##########
@@ -989,23 +1123,41 @@ struct IfElseFunction : ScalarFunction {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;
- if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ // Do not DispatchExact here because it'll let through something like
(bool,
+ // timestamp[s], timestamp[s, "UTC"])
// if 0th descriptor is null, replace with bool
if (values->at(0).type->id() == Type::NA) {
values->at(0).type = boolean();
}
// if-else 0'th descriptor is bool, so skip it
- std::vector<ValueDescr> values_copy(values->begin() + 1, values->end());
- internal::EnsureDictionaryDecoded(&values_copy);
- internal::ReplaceNullWithOtherType(&values_copy);
+ ValueDescr* left_arg = &(*values)[1];
+ constexpr size_t num_args = 2;
- if (auto type = internal::CommonNumeric(values_copy)) {
- internal::ReplaceTypes(type, &values_copy);
+ internal::ReplaceNullWithOtherType(left_arg, num_args);
+
+ if (is_dictionary((*values)[1].type->id()) &&
+ (*values)[1].type->Equals(*(*values)[2].type)) {
+ auto kernel = DispatchExactImpl(this, *values);
+ DCHECK(kernel);
+ return kernel;
}
- std::move(values_copy.begin(), values_copy.end(), values->begin() + 1);
+ internal::EnsureDictionaryDecoded(left_arg, num_args);
+
+ if (auto type = internal::CommonNumeric(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ }
+ if (auto type = internal::CommonTemporal(left_arg, num_args)) {
Review comment:
Shouldn't this be a chain of "else if"?
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -103,7 +118,7 @@ void CheckWithDifferentShapes(const std::shared_ptr<Array>&
cond,
auto len = left->length();
enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
- for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR);
++mask) {
+ for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR);
++mask) {
for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) {
Review comment:
Speaking of which, if `mask & COND_SCALAR` is false, won't this loop
`len` times doing the same thing?
##########
File path: cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
##########
@@ -103,7 +118,7 @@ void CheckWithDifferentShapes(const std::shared_ptr<Array>&
cond,
auto len = left->length();
enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
- for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR);
++mask) {
+ for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR);
++mask) {
for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) {
Review comment:
You could avoid that, e.g.:
```c++
std::vector<int64_t> array_indices{-1}; // sentinel
std::vector<int64_t> scalar_indices(len);
std::iota(scalar_indices.begin(), scalar_indices.end(), 0);
auto make_input = [&](const std::shared_ptr<Array>& array, int64_t index,
Datum* input, Datum* input_bcast, std::string* trace) {
if (index >= 0) {
// Use scalar as input, broadcast it for expected result computation
ASSERT_OK_AND_ASSIGN(*input, array->GetScalar(index));
ASSERT_OK_AND_ASSIGN(*input_bcast,
MakeArrayFromScalar(*input->scalar(), len));
*trace += "...";
} else {
// Use array as input
*input = *input_bcast = array;
*trace += "=Array";
}
};
for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR);
++mask) {
for (const auto cond_idx : (mask & COND_SCALAR) ? scalar_indices :
array_indices) {
Datum cond_in, cond_bcast;
std::string trace_cond = "Cond";
make_input(cond, cond_idx, &cond_in, &cond_bcast, &trace_cond);
for (const auto left_idx : (mask & LEFT_SCALAR) ? scalar_indices :
array_indices) {
// etc.
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]