edponce commented on a change in pull request #11023:
URL: https://github.com/apache/arrow/pull/11023#discussion_r731389134
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -417,6 +419,231 @@ struct StringTransformExecWithState
}
};
+struct StringBinaryTransformBase {
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return Status::OK();
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics.
+ virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
+ const std::shared_ptr<Scalar>& input2) {
+ return input_ncodeunits;
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics.
+ virtual int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits,
+ const std::shared_ptr<ArrayData>& data2) {
+ return input_ncodeunits;
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+};
+
+/// Kernel exec generator for binary string transforms.
+/// The first parameter is expected to always be a string type while the second parameter
+/// is generic. It supports executions of the form:
+/// * Scalar, Scalar
+/// * Array, Scalar - scalar is broadcasted and paired with all values of array
+/// * Array, Array - arrays are processed element-wise
+/// * Scalar, Array - not supported by default
+template <typename Type1, typename Type2, typename StringTransform>
+struct StringBinaryTransformExecBase {
+ using offset_type = typename Type1::offset_type;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ static Status Execute(KernelContext* ctx, StringTransform* transform,
+ const ExecBatch& batch, Datum* out) {
+ if (batch.num_values() != 2) {
+ return Status::Invalid("Invalid arity for binary string transform");
+ }
+
+ if (batch[0].is_array()) {
+ if (batch[1].is_array()) {
+ return ExecArrayArray(ctx, transform, batch[0].array(), batch[1].array(), out);
+ } else if (batch[1].is_scalar()) {
+ return ExecArrayScalar(ctx, transform, batch[0].array(), batch[1].scalar(), out);
+ }
+ } else if (batch[0].is_scalar()) {
+ if (batch[1].is_array()) {
+ return ExecScalarArray(ctx, transform, batch[0].scalar(), batch[1].array(), out);
+ } else if (batch[1].is_scalar()) {
+ return ExecScalarScalar(ctx, transform, batch[0].scalar(), batch[1].scalar(), out);
+ }
+ }
+ return Status::Invalid("Invalid ExecBatch kind for binary string transform");
+ }
+
+ static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<Scalar>& scalar1,
+ const std::shared_ptr<Scalar>& scalar2, Datum* out) {
+ if (!scalar1->is_valid || !scalar2->is_valid) {
+ return Status::OK();
+ }
+
+ const auto& input1 = checked_cast<const BaseBinaryScalar&>(*scalar1);
+ auto input_ncodeunits = input1.value->size();
+ auto input_nstrings = 1;
+ auto output_ncodeunits_max =
+ transform->MaxCodeunits(input_nstrings, input_ncodeunits, scalar2);
+ if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max));
+ auto result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+ result->is_valid = true;
+ result->value = value_buffer;
+ auto output_str = value_buffer->mutable_data();
+
+ auto input1_string = input1.value->data();
+ auto encoded_nbytes = static_cast<offset_type>(
+ transform->Transform(input1_string, input_ncodeunits, scalar2, output_str));
+ if (encoded_nbytes < 0) {
+ return transform->InvalidStatus();
+ }
+ DCHECK_LE(encoded_nbytes, output_ncodeunits_max);
+ return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true);
+ }
+
+ static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<ArrayData>& data1,
+ const std::shared_ptr<Scalar>& scalar2, Datum* out) {
+ if (!scalar2->is_valid) {
+ return Status::OK();
+ }
+
+ ArrayType1 input1(data1);
+ auto input1_ncodeunits = input1.total_values_length();
+ auto input1_nstrings = input1.length();
+ auto output_ncodeunits_max =
+ transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, scalar2);
+ if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+ }
+
+ ArrayData* output = out->mutable_array();
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
+ output->buffers[2] = values_buffer;
+
+ // String offsets are preallocated
+ auto output_string_offsets = output->GetMutableValues<offset_type>(1);
+ auto output_str = output->buffers[2]->mutable_data();
+ output_string_offsets[0] = 0;
+
+ offset_type output_ncodeunits = 0;
+ for (int64_t i = 0; i < input1_nstrings; ++i) {
+ if (!input1.IsNull(i)) {
+ offset_type input1_string_ncodeunits;
+ auto input1_string = input1.GetValue(i, &input1_string_ncodeunits);
+ auto encoded_nbytes = static_cast<offset_type>(
+ transform->Transform(input1_string, input1_string_ncodeunits, scalar2,
+ output_str + output_ncodeunits));
+ if (encoded_nbytes < 0) {
+ return transform->InvalidStatus();
+ }
+ output_ncodeunits += encoded_nbytes;
+ }
+ output_string_offsets[i + 1] = output_ncodeunits;
+ }
+ DCHECK_LE(output_ncodeunits, output_ncodeunits_max);
+
+ // Trim the codepoint buffer, since we allocated too much
+ return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true);
+ }
+
+ static Status ExecScalarArray(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<Scalar>& scalar1,
+ const std::shared_ptr<ArrayData>& data2, Datum* out) {
+ return Status::NotImplemented(
+ "Binary string transforms with (scalar, array) inputs are not
supported for the "
+ "general case");
+ }
+
+ static Status ExecArrayArray(KernelContext* ctx, StringTransform* transform,
+ const std::shared_ptr<ArrayData>& data1,
+ const std::shared_ptr<ArrayData>& data2, Datum* out) {
+ ArrayType1 input1(data1);
+ ArrayType2 input2(data2);
+
+ auto input1_ncodeunits = input1.total_values_length();
+ auto input1_nstrings = input1.length();
+ auto output_ncodeunits_max =
+ transform->MaxCodeunits(input1_nstrings, input1_ncodeunits, data2);
+ if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
+ return Status::CapacityError(
+ "Result might not fit in a 32bit utf8 array, convert to large_utf8");
+ }
+
+ ArrayData* output = out->mutable_array();
+ ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max));
+ output->buffers[2] = values_buffer;
+
+ // String offsets are preallocated
+ auto output_string_offsets = output->GetMutableValues<offset_type>(1);
+ auto output_str = output->buffers[2]->mutable_data();
+ output_string_offsets[0] = 0;
+
+ offset_type output_ncodeunits = 0;
+ for (int64_t i = 0; i < input1_nstrings; ++i) {
+ if (!input1.IsNull(i) || !input2.IsNull(i)) {
Review comment:
It seems we cannot use the `VisitBitBlocks` utilities here: when processing an `Array`, the current `StringBinaryTransformExecBase` implementation has to write the output string offsets (`output_string_offsets`) at both non-null and null positions, and that requires the visited `position` to be available in both visitors.
```c++
offset_type output_ncodeunits = 0;
for (int64_t i = 0; i < input1_nstrings; ++i) {
if (!input1.IsNull(i)) {
...
offset_type encoded_bytes = Transform(...);
...
output_ncodeunits += encoded_bytes;
}
// This needs to be updated for Null/NotNull visitors
output_string_offsets[i + 1] = output_ncodeunits;
}
```
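For context, here is a rough sketch of the visitor shape those utilities expect, assuming the `VisitBitBlocksVoid` helper from `arrow/util/bit_block_counter.h` (illustrative only, not code from this PR): the not-null visitor receives the position but the null visitor does not, so the offset bookkeeping for null slots would need an index tracked outside the callbacks.
```c++
// Illustrative sketch, not PR code: with VisitBitBlocksVoid only the
// not-null visitor is given the position.
offset_type output_ncodeunits = 0;
VisitBitBlocksVoid(
    data1->buffers[0], data1->offset, data1->length,
    /*visit_not_null=*/
    [&](int64_t i) {
      // Transform(...), advance output_ncodeunits, then:
      // output_string_offsets[i + 1] = output_ncodeunits;
    },
    /*visit_null=*/
    [&]() {
      // No position argument here, yet the offset for this slot
      // still has to be written.
    });
```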