westonpace commented on a change in pull request #11023:
URL: https://github.com/apache/arrow/pull/11023#discussion_r738333645
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -402,16 +401,16 @@ struct StringTransformExecBase {
if (!input.is_valid) {
return Status::OK();
}
- auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
- result->is_valid = true;
const int64_t data_nbytes = static_cast<int64_t>(input.value->size());
-
const int64_t output_ncodeunits_max = transform->MaxCodeunits(1,
data_nbytes);
if (output_ncodeunits_max > std::numeric_limits<offset_type>::max()) {
return Status::CapacityError(
"Result might not fit in a 32bit utf8 array, convert to large_utf8");
}
+
ARROW_ASSIGN_OR_RAISE(auto value_buffer,
ctx->Allocate(output_ncodeunits_max));
+ auto* result = checked_cast<BaseBinaryScalar*>(out->scalar().get());
Review comment:
Nit: You don't really have to, and I don't know how this could possibly
not work, but I'm in the habit of adding `DCHECK_NE(result, nullptr);` after any
`checked_cast`.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -537,6 +536,341 @@ struct FixedSizeBinaryTransformExecWithState
}
};
+template <typename Type1, typename Type2>
+struct StringBinaryTransformBase {
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ virtual ~StringBinaryTransformBase() = default;
+
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum*
out) {
+ return Status::OK();
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics for different input shapes.
+ //
+ // Scalar-Scalar
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ViewType2) {
+ return input1_ncodeunits;
+ }
+
+ // Scalar-Array
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ArrayType2&) {
+ return input1_ncodeunits;
+ }
+
+ // Array-Scalar
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ViewType2) {
+ return input1.total_values_length();
+ }
+
+ // Array-Array
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ArrayType2&) {
+ return input1.total_values_length();
+ }
+
+ // Not all combinations of input shapes are meaningful to string binary
+ // transforms, so these flags serve as control toggles for enabling/disabling
+ // the corresponding ones. These flags should be set in the PreExec() method.
+ //
+ // This is an example of a StringTransform that disables argument shapes with
+ // mixed scalar/array.
+ //
+ // template <typename Type1, typename Type2>
+ // struct MyStringTransform : public StringBinaryTransformBase<Type1, Type2>
{
+ // Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out)
override {
+ // EnableScalarArray = false;
+ // EnableArrayScalar = false;
+ // return StringBinaryTransformBase::PreExec(ctx, batch, out);
+ // }
+ // ...
+ // };
+ bool EnableScalarScalar = true;
+ bool EnableScalarArray = true;
+ bool EnableArrayScalar = true;
+ bool EnableArrayArray = true;
+
+ // Tracks status of transform in StringBinaryTransformExecBase.
+ // The purpose of this transform status is to provide a means to
report/detect
+ // errors in functions that do not provide a mechanism to return a Status
+ // value but can still detect errors. This status is checked automatically
+ // after MaxCodeunits() and Transform() operations.
+ Status st = Status::OK();
+};
+
+/// Kernel exec generator for binary (two parameters) string transforms.
+/// The first parameter is expected to always be a Binary/StringType while the
+/// second parameter is generic. Types of template parameter StringTransform
+/// need to define a transform method with the following signature:
+///
+/// int64_t Transform(const uint8_t* input, const int64_t
input_string_ncodeunits,
+/// const ViewType2 value2, uint8_t* output);
Review comment:
Ah, never mind, I see now that it is overridden if the kernel is
lengthening the string.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -537,6 +536,341 @@ struct FixedSizeBinaryTransformExecWithState
}
};
+template <typename Type1, typename Type2>
+struct StringBinaryTransformBase {
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ virtual ~StringBinaryTransformBase() = default;
+
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum*
out) {
+ return Status::OK();
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics for different input shapes.
+ //
+ // Scalar-Scalar
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ViewType2) {
+ return input1_ncodeunits;
+ }
+
+ // Scalar-Array
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ArrayType2&) {
+ return input1_ncodeunits;
+ }
+
+ // Array-Scalar
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ViewType2) {
+ return input1.total_values_length();
+ }
+
+ // Array-Array
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ArrayType2&) {
+ return input1.total_values_length();
+ }
+
+ // Not all combinations of input shapes are meaningful to string binary
+ // transforms, so these flags serve as control toggles for enabling/disabling
+ // the corresponding ones. These flags should be set in the PreExec() method.
+ //
+ // This is an example of a StringTransform that disables argument shapes with
+ // mixed scalar/array.
+ //
+ // template <typename Type1, typename Type2>
+ // struct MyStringTransform : public StringBinaryTransformBase<Type1, Type2>
{
+ // Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out)
override {
+ // EnableScalarArray = false;
+ // EnableArrayScalar = false;
+ // return StringBinaryTransformBase::PreExec(ctx, batch, out);
+ // }
+ // ...
+ // };
+ bool EnableScalarScalar = true;
+ bool EnableScalarArray = true;
+ bool EnableArrayScalar = true;
+ bool EnableArrayArray = true;
+
+ // Tracks status of transform in StringBinaryTransformExecBase.
+ // The purpose of this transform status is to provide a means to
report/detect
+ // errors in functions that do not provide a mechanism to return a Status
+ // value but can still detect errors. This status is checked automatically
+ // after MaxCodeunits() and Transform() operations.
+ Status st = Status::OK();
+};
+
+/// Kernel exec generator for binary (two parameters) string transforms.
+/// The first parameter is expected to always be a Binary/StringType while the
+/// second parameter is generic. Types of template parameter StringTransform
+/// need to define a transform method with the following signature:
+///
+/// int64_t Transform(const uint8_t* input, const int64_t
input_string_ncodeunits,
+/// const ViewType2 value2, uint8_t* output);
Review comment:
This may just be my ignorance with these kinds of generators, but it isn't
obvious to me what the return value represents.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -537,6 +536,341 @@ struct FixedSizeBinaryTransformExecWithState
}
};
+template <typename Type1, typename Type2>
+struct StringBinaryTransformBase {
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ virtual ~StringBinaryTransformBase() = default;
+
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum*
out) {
+ return Status::OK();
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics for different input shapes.
+ //
+ // Scalar-Scalar
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ViewType2) {
+ return input1_ncodeunits;
+ }
+
+ // Scalar-Array
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ArrayType2&) {
+ return input1_ncodeunits;
+ }
+
+ // Array-Scalar
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ViewType2) {
+ return input1.total_values_length();
+ }
+
+ // Array-Array
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ArrayType2&) {
+ return input1.total_values_length();
+ }
+
+ // Not all combinations of input shapes are meaningful to string binary
+ // transforms, so these flags serve as control toggles for enabling/disabling
+ // the corresponding ones. These flags should be set in the PreExec() method.
+ //
+ // This is an example of a StringTransform that disables argument shapes with
+ // mixed scalar/array.
+ //
+ // template <typename Type1, typename Type2>
+ // struct MyStringTransform : public StringBinaryTransformBase<Type1, Type2>
{
+ // Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out)
override {
+ // EnableScalarArray = false;
+ // EnableArrayScalar = false;
+ // return StringBinaryTransformBase::PreExec(ctx, batch, out);
+ // }
+ // ...
+ // };
+ bool EnableScalarScalar = true;
+ bool EnableScalarArray = true;
+ bool EnableArrayScalar = true;
+ bool EnableArrayArray = true;
+
+ // Tracks status of transform in StringBinaryTransformExecBase.
+ // The purpose of this transform status is to provide a means to
report/detect
+ // errors in functions that do not provide a mechanism to return a Status
+ // value but can still detect errors. This status is checked automatically
+ // after MaxCodeunits() and Transform() operations.
+ Status st = Status::OK();
+};
+
+/// Kernel exec generator for binary (two parameters) string transforms.
+/// The first parameter is expected to always be a Binary/StringType while the
+/// second parameter is generic. Types of template parameter StringTransform
+/// need to define a transform method with the following signature:
+///
+/// int64_t Transform(const uint8_t* input, const int64_t
input_string_ncodeunits,
+/// const ViewType2 value2, uint8_t* output);
+template <typename Type1, typename Type2, typename StringTransform>
+struct StringBinaryTransformExecBase {
+ using offset_type = typename Type1::offset_type;
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ static Status Execute(KernelContext* ctx, StringTransform* transform,
+ const ExecBatch& batch, Datum* out) {
+ if (batch[0].is_scalar()) {
+ if (batch[1].is_scalar()) {
+ if (transform->EnableScalarScalar) {
+ return ExecScalarScalar(ctx, transform, batch[0].scalar(),
batch[1].scalar(),
+ out);
+ }
+ } else if (batch[1].is_array()) {
+ if (transform->EnableScalarArray) {
+ return ExecScalarArray(ctx, transform, batch[0].scalar(),
batch[1].array(),
+ out);
+ }
+ }
+ } else if (batch[0].is_array()) {
+ if (batch[1].is_array()) {
+ if (transform->EnableArrayArray) {
+ return ExecArrayArray(ctx, transform, batch[0].array(),
batch[1].array(), out);
+ }
+ } else if (batch[1].is_scalar()) {
+ if (transform->EnableArrayScalar) {
+ return ExecArrayScalar(ctx, transform, batch[0].array(),
batch[1].scalar(),
+ out);
+ }
+ }
+ }
+ return Status::Invalid("Invalid ExecBatch kind for binary string
transform");
+ }
+
+ static Status ExecScalarScalar(KernelContext* ctx, StringTransform*
transform,
+ const std::shared_ptr<Scalar>& scalar1,
+ const std::shared_ptr<Scalar>& scalar2,
Datum* out) {
+ if (!scalar1->is_valid || !scalar2->is_valid) {
+ return Status::OK();
+ }
+ const auto& binary_scalar1 = checked_cast<const
BaseBinaryScalar&>(*scalar1);
+ const auto input_string = binary_scalar1.value->data();
+ const auto input_ncodeunits = binary_scalar1.value->size();
+ const auto value2 = UnboxScalar<Type2>::Unbox(*scalar2);
+
+ // Calculate max number of output codeunits
+ const auto max_output_ncodeunits =
transform->MaxCodeunits(input_ncodeunits, value2);
Review comment:
This is more a learning question for me, but how can you calculate the
output size based on the input? Don't some kernels output strings that are
longer than the input?
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -537,6 +536,341 @@ struct FixedSizeBinaryTransformExecWithState
}
};
+template <typename Type1, typename Type2>
+struct StringBinaryTransformBase {
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ virtual ~StringBinaryTransformBase() = default;
+
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum*
out) {
+ return Status::OK();
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
Review comment:
Minor nit: Maybe name this `InvalidUtf8Sequence()`
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -537,6 +536,341 @@ struct FixedSizeBinaryTransformExecWithState
}
};
+template <typename Type1, typename Type2>
+struct StringBinaryTransformBase {
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ virtual ~StringBinaryTransformBase() = default;
+
+ virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum*
out) {
+ return Status::OK();
+ }
+
+ virtual Status InvalidStatus() {
+ return Status::Invalid("Invalid UTF8 sequence in input");
+ }
+
+ // Return the maximum total size of the output in codeunits (i.e. bytes)
+ // given input characteristics for different input shapes.
+ //
+ // Scalar-Scalar
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ViewType2) {
+ return input1_ncodeunits;
+ }
+
+ // Scalar-Array
+ virtual int64_t MaxCodeunits(const int64_t input1_ncodeunits, const
ArrayType2&) {
+ return input1_ncodeunits;
+ }
+
+ // Array-Scalar
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ViewType2) {
+ return input1.total_values_length();
+ }
+
+ // Array-Array
+ virtual int64_t MaxCodeunits(const ArrayType1& input1, const ArrayType2&) {
+ return input1.total_values_length();
+ }
+
+ // Not all combinations of input shapes are meaningful to string binary
+ // transforms, so these flags serve as control toggles for enabling/disabling
+ // the corresponding ones. These flags should be set in the PreExec() method.
+ //
+ // This is an example of a StringTransform that disables argument shapes with
+ // mixed scalar/array.
+ //
+ // template <typename Type1, typename Type2>
+ // struct MyStringTransform : public StringBinaryTransformBase<Type1, Type2>
{
+ // Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out)
override {
+ // EnableScalarArray = false;
+ // EnableArrayScalar = false;
+ // return StringBinaryTransformBase::PreExec(ctx, batch, out);
+ // }
+ // ...
+ // };
+ bool EnableScalarScalar = true;
+ bool EnableScalarArray = true;
+ bool EnableArrayScalar = true;
+ bool EnableArrayArray = true;
+
+ // Tracks status of transform in StringBinaryTransformExecBase.
+ // The purpose of this transform status is to provide a means to
report/detect
+ // errors in functions that do not provide a mechanism to return a Status
+ // value but can still detect errors. This status is checked automatically
+ // after MaxCodeunits() and Transform() operations.
+ Status st = Status::OK();
+};
+
+/// Kernel exec generator for binary (two parameters) string transforms.
+/// The first parameter is expected to always be a Binary/StringType while the
+/// second parameter is generic. Types of template parameter StringTransform
+/// need to define a transform method with the following signature:
+///
+/// int64_t Transform(const uint8_t* input, const int64_t
input_string_ncodeunits,
+/// const ViewType2 value2, uint8_t* output);
+template <typename Type1, typename Type2, typename StringTransform>
+struct StringBinaryTransformExecBase {
+ using offset_type = typename Type1::offset_type;
+ using ViewType2 = typename GetViewType<Type2>::T;
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ static Status Execute(KernelContext* ctx, StringTransform* transform,
+ const ExecBatch& batch, Datum* out) {
+ if (batch[0].is_scalar()) {
+ if (batch[1].is_scalar()) {
+ if (transform->EnableScalarScalar) {
+ return ExecScalarScalar(ctx, transform, batch[0].scalar(),
batch[1].scalar(),
+ out);
+ }
+ } else if (batch[1].is_array()) {
+ if (transform->EnableScalarArray) {
+ return ExecScalarArray(ctx, transform, batch[0].scalar(),
batch[1].array(),
+ out);
+ }
+ }
+ } else if (batch[0].is_array()) {
+ if (batch[1].is_array()) {
+ if (transform->EnableArrayArray) {
+ return ExecArrayArray(ctx, transform, batch[0].array(),
batch[1].array(), out);
+ }
+ } else if (batch[1].is_scalar()) {
+ if (transform->EnableArrayScalar) {
+ return ExecArrayScalar(ctx, transform, batch[0].array(),
batch[1].scalar(),
+ out);
+ }
+ }
+ }
+ return Status::Invalid("Invalid ExecBatch kind for binary string
transform");
Review comment:
This probably won't happen too often, but could you include the two
"kinds" that led to this? For example, it would be much clearer to the user to
see "Invalid combination of operands for binary string transform fn-name
(array, scalar)".
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]