edponce commented on a change in pull request #11023:
URL: https://github.com/apache/arrow/pull/11023#discussion_r740687132
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -2513,6 +2877,159 @@ void AddSplit(FunctionRegistry* registry) {
#endif
}
+/// An ScalarFunction that promotes integer arguments to Int64.
+struct ScalarCTypeToInt64Function : public ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const
override {
+ RETURN_NOT_OK(CheckArity(*values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ for (auto& descr : *values) {
+ if (is_integer(descr.type->id())) {
+ descr.type = int64();
+ }
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+template <typename Type1, typename Type2>
+struct StringRepeatTransform : public StringBinaryTransformBase<Type1, Type2> {
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ int64_t MaxCodeunits(const int64_t input1_ncodeunits, const int64_t
num_repeats,
+ Status*) override {
+ return input1_ncodeunits * num_repeats;
+ }
+
+ int64_t MaxCodeunits(const int64_t input1_ncodeunits, const ArrayType2&
input2,
+ Status* st) override {
+ int64_t total_num_repeats = 0;
+ for (int64_t i = 0; i < input2.length(); ++i) {
+ auto num_repeats = input2.GetView(i);
+ if (num_repeats < 0) {
+ *st = InvalidRepeatCount();
+ return num_repeats;
+ }
+ total_num_repeats += num_repeats;
+ }
+ return input1_ncodeunits * total_num_repeats;
+ }
+
+ int64_t MaxCodeunits(const ArrayType1& input1, const int64_t num_repeats,
+ Status*) override {
+ return input1.total_values_length() * num_repeats;
+ }
+
+ int64_t MaxCodeunits(const ArrayType1& input1, const ArrayType2& input2,
+ Status* st) override {
+ int64_t total_codeunits = 0;
+ for (int64_t i = 0; i < input2.length(); ++i) {
+ auto num_repeats = input2.GetView(i);
+ if (num_repeats < 0) {
+ *st = InvalidRepeatCount();
+ return num_repeats;
+ }
+ total_codeunits += input1.GetView(i).length() * num_repeats;
+ }
+ return total_codeunits;
+ }
+
+ std::function<int64_t(const uint8_t*, const int64_t, const int64_t,
uint8_t*, Status*)>
+ Transform;
+
+ static int64_t TransformSimple(const uint8_t* input,
+ const int64_t input_string_ncodeunits,
+ const int64_t num_repeats, uint8_t* output,
Status*) {
+ uint8_t* output_start = output;
+ for (int64_t i = 0; i < num_repeats; ++i) {
+ std::memcpy(output, input, input_string_ncodeunits);
+ output += input_string_ncodeunits;
+ }
+ return output - output_start;
+ }
+
+ static int64_t TransformDoubling(const uint8_t* input,
+ const int64_t input_string_ncodeunits,
+ const int64_t num_repeats, uint8_t* output,
Status*) {
+ uint8_t* output_start = output;
+ // Repeated doubling of string
+ std::memcpy(output, input, input_string_ncodeunits);
+ output += input_string_ncodeunits;
+ int64_t i = 1;
+ for (int64_t ilen = input_string_ncodeunits; i <= (num_repeats / 2);
+ i *= 2, ilen *= 2) {
+ std::memcpy(output, output_start, ilen);
+ output += ilen;
+ }
+
+ // Epilogue remainder
+ int64_t rem = (num_repeats ^ i) * input_string_ncodeunits;
+ std::memcpy(output, output_start, rem);
+ output += rem;
+ return output - output_start;
+ }
+
+ static int64_t TransformWrapper(const uint8_t* input,
+ const int64_t input_string_ncodeunits,
+ const int64_t num_repeats, uint8_t* output,
+ Status* st) {
+ auto transform = (num_repeats < 4) ? TransformSimple : TransformDoubling;
+ return transform(input, input_string_ncodeunits, num_repeats, output, st);
+ }
+
+ Status PreExec(KernelContext*, const ExecBatch& batch, Datum*) override {
+ // For cases with a scalar repeat count, select the best implementation
once
+ // before execution. Otherwise, use TransformWrapper to select
implementation
+ // when processing each value.
Review comment:
Using `std::function` indirection resulted in a 2x slowdown, so good
call/intuition on this one.
```
StringRepeat_mean 622822087 ns 622817699 ns 10
bytes_per_second=25.4509M/s items_per_second=1.68498M/s
StringRepeat_median 623393064 ns 623390528 ns 10
bytes_per_second=25.4067M/s items_per_second=1.68205M/s
StringRepeat_stddev 18771545 ns 18770743 ns 10
bytes_per_second=787.511k/s items_per_second=50.9153k/s
```
Checking `num_repeats < 4` at each iteration
```
StringRepeat_mean 313125674 ns 313123902 ns 10
bytes_per_second=50.601M/s items_per_second=3.35004M/s
StringRepeat_median 312795031 ns 312794088 ns 10
bytes_per_second=50.6405M/s items_per_second=3.35266M/s
StringRepeat_stddev 6484104 ns 6484645 ns 10
bytes_per_second=1068k/s items_per_second=69.0502k/s
```
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -19,6 +19,7 @@
#include <cctype>
#include <iterator>
#include <string>
+#include <typeinfo>
Review comment:
No, I used it when trying to print the StringTransform type using
`typeid(t).name()` but it printed more info than needed.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -19,6 +19,7 @@
#include <cctype>
#include <iterator>
#include <string>
+#include <typeinfo>
Review comment:
No, I used it when trying to print the `StringTransform` type using
`typeid(t).name()` but it printed more info than needed.
##########
File path: cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
##########
@@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include <utility>
+#include <vector>
Review comment:
Not for this PR, so I will revert it. Nevertheless, I have noticed that
several files are missing includes and probably have some unnecessary ones. I
think this should be its own JIRA issue.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -2513,6 +2878,135 @@ void AddSplit(FunctionRegistry* registry) {
#endif
}
+/// An ScalarFunction that promotes integer arguments to Int64.
+struct ScalarCTypeToInt64Function : public ScalarFunction {
+ using ScalarFunction::ScalarFunction;
+
+ Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const
override {
+ RETURN_NOT_OK(CheckArity(*values));
+
+ using arrow::compute::detail::DispatchExactImpl;
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+ EnsureDictionaryDecoded(values);
+
+ for (auto& descr : *values) {
+ if (is_integer(descr.type->id())) {
+ descr.type = int64();
+ }
+ }
+
+ if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ return arrow::compute::detail::NoMatchingKernel(this, *values);
+ }
+};
+
+template <typename Type1, typename Type2>
+struct StringRepeatTransform : public StringBinaryTransformBase<Type1, Type2> {
+ using ArrayType1 = typename TypeTraits<Type1>::ArrayType;
+ using ArrayType2 = typename TypeTraits<Type2>::ArrayType;
+
+ Result<int64_t> MaxCodeunits(const int64_t input1_ncodeunits,
+ const int64_t num_repeats) override {
+ ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats));
+ return input1_ncodeunits * num_repeats;
+ }
+
+ Result<int64_t> MaxCodeunits(const int64_t input1_ncodeunits,
+ const ArrayType2& input2) override {
+ int64_t total_num_repeats = 0;
+ for (int64_t i = 0; i < input2.length(); ++i) {
+ auto num_repeats = input2.GetView(i);
+ ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats));
+ total_num_repeats += num_repeats;
+ }
+ return input1_ncodeunits * total_num_repeats;
+ }
+
+ Result<int64_t> MaxCodeunits(const ArrayType1& input1,
+ const int64_t num_repeats) override {
+ ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats));
+ return input1.total_values_length() * num_repeats;
+ }
+
+ Result<int64_t> MaxCodeunits(const ArrayType1& input1,
+ const ArrayType2& input2) override {
+ int64_t total_codeunits = 0;
+ for (int64_t i = 0; i < input2.length(); ++i) {
+ auto num_repeats = input2.GetView(i);
+ ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats));
+ total_codeunits += input1.GetView(i).length() * num_repeats;
+ }
+ return total_codeunits;
+ }
+
+ static Result<int64_t> TransformSimpleLoop(const uint8_t* input,
+ const int64_t
input_string_ncodeunits,
+ const int64_t num_repeats,
uint8_t* output) {
+ uint8_t* output_start = output;
+ for (int64_t i = 0; i < num_repeats; ++i) {
+ std::memcpy(output, input, input_string_ncodeunits);
+ output += input_string_ncodeunits;
+ }
+ return output - output_start;
+ }
+
+ static Result<int64_t> TransformDoublingString(const uint8_t* input,
+ const int64_t
input_string_ncodeunits,
+ const int64_t num_repeats,
+ uint8_t* output) {
+ uint8_t* output_start = output;
+ // Repeated doubling of string
+ std::memcpy(output, input, input_string_ncodeunits);
+ output += input_string_ncodeunits;
+ int64_t i = 1;
+ for (int64_t ilen = input_string_ncodeunits; i <= (num_repeats / 2);
+ i *= 2, ilen *= 2) {
+ std::memcpy(output, output_start, ilen);
+ output += ilen;
+ }
+
+ // Epilogue remainder
+ int64_t rem = (num_repeats ^ i) * input_string_ncodeunits;
Review comment:
Not really; `xor` is just representing `mod 2`, but in this case
subtraction is also valid.
Changed it to subtraction and renamed the variable to `irep` for improved
readability.
##########
File path: cpp/src/arrow/compute/kernels/scalar_string_test.cc
##########
@@ -1041,6 +1037,73 @@ TYPED_TEST(TestStringKernels, Utf8Title) {
R"([null, "", "B", "Aaaz;Zææ&", "Ɑɽɽow", "Ii", "Ⱥ.Ⱥ.Ⱥ..Ⱥ", "Hello,
World!", "Foo Bar;Héhé0Zop", "!%$^.,;"])");
}
+TYPED_TEST(TestStringKernels, StringRepeatWithScalarRepeat) {
Review comment:
Yes, it is implicit in `CheckVarArgs`. `CheckVarArgs` invokes
[`CheckScalar`, which internally calls the function for each scalar
input](https://github.com/apache/arrow/blob/master/cpp/src/arrow/compute/kernels/test_util.cc#L127).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]