This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 2b4a703202 GH-39231: [C++][Compute] Add binary_slice kernel for fixed
size binary (#39245)
2b4a703202 is described below
commit 2b4a70320232647f730b19d2fea5746c3baec752
Author: Jin Shang <[email protected]>
AuthorDate: Fri Jan 12 01:56:46 2024 +0800
GH-39231: [C++][Compute] Add binary_slice kernel for fixed size binary
(#39245)
### Rationale for this change
Add binary_slice kernel for fixed size binary
### What changes are included in this PR?
Add binary_slice kernel for fixed size binary
### Are these changes tested?
Yes
### Are there any user-facing changes?
No
* Closes: #39231
Lead-authored-by: Jin Shang <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
.../arrow/compute/kernels/scalar_string_ascii.cc | 117 ++++++++++++-----
.../arrow/compute/kernels/scalar_string_internal.h | 2 +
.../arrow/compute/kernels/scalar_string_test.cc | 146 +++++++++++++++++++--
python/pyarrow/tests/test_compute.py | 10 +-
4 files changed, 233 insertions(+), 42 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
index 6764845dfc..8fdc6172aa 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
@@ -95,7 +95,7 @@ struct FixedSizeBinaryTransformExecBase {
ctx->Allocate(output_width * input_nstrings));
uint8_t* output_str = values_buffer->mutable_data();
- const uint8_t* input_data = input.GetValues<uint8_t>(1);
+ const uint8_t* input_data = input.GetValues<uint8_t>(1, input.offset *
input_width);
for (int64_t i = 0; i < input_nstrings; i++) {
if (!input.IsNull(i)) {
const uint8_t* input_string = input_data + i * input_width;
@@ -132,7 +132,8 @@ struct FixedSizeBinaryTransformExecWithState
DCHECK_EQ(1, types.size());
const auto& options = State::Get(ctx);
const int32_t input_width = types[0].type->byte_width();
- const int32_t output_width = StringTransform::FixedOutputSize(options,
input_width);
+ ARROW_ASSIGN_OR_RAISE(const int32_t output_width,
+ StringTransform::FixedOutputSize(options,
input_width));
return fixed_size_binary(output_width);
}
};
@@ -2377,7 +2378,8 @@ struct BinaryReplaceSliceTransform :
ReplaceStringSliceTransformBase {
return output - output_start;
}
- static int32_t FixedOutputSize(const ReplaceSliceOptions& opts, int32_t
input_width) {
+ static Result<int32_t> FixedOutputSize(const ReplaceSliceOptions& opts,
+ int32_t input_width) {
int32_t before_slice = 0;
int32_t after_slice = 0;
const int32_t start = static_cast<int32_t>(opts.start);
@@ -2436,6 +2438,7 @@ void AddAsciiStringReplaceSlice(FunctionRegistry*
registry) {
namespace {
struct SliceBytesTransform : StringSliceTransformBase {
+ using StringSliceTransformBase::StringSliceTransformBase;
int64_t MaxCodeunits(int64_t ninputs, int64_t input_bytes) override {
const SliceOptions& opt = *this->options;
if ((opt.start >= 0) != (opt.stop >= 0)) {
@@ -2454,22 +2457,15 @@ struct SliceBytesTransform : StringSliceTransformBase {
return SliceBackward(input, input_string_bytes, output);
}
- int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
- uint8_t* output) {
- // Slice in forward order (step > 0)
- const SliceOptions& opt = *this->options;
- const uint8_t* begin = input;
- const uint8_t* end = input + input_string_bytes;
- const uint8_t* begin_sliced;
- const uint8_t* end_sliced;
-
- if (!input_string_bytes) {
- return 0;
- }
- // First, compute begin_sliced and end_sliced
+ static std::pair<int64_t, int64_t> SliceForwardRange(const SliceOptions& opt,
+ int64_t
input_string_bytes) {
+ int64_t begin = 0;
+ int64_t end = input_string_bytes;
+ int64_t begin_sliced = 0;
+ int64_t end_sliced = 0;
if (opt.start >= 0) {
// start counting from the left
- begin_sliced = std::min(begin + opt.start, end);
+ begin_sliced = std::min(opt.start, end);
if (opt.stop > opt.start) {
// continue counting from begin_sliced
const int64_t length = opt.stop - opt.start;
@@ -2479,7 +2475,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
end_sliced = std::max(end + opt.stop, begin_sliced);
} else {
// zero length slice
- return 0;
+ return {0, 0};
}
} else {
// start counting from the right
@@ -2491,7 +2487,7 @@ struct SliceBytesTransform : StringSliceTransformBase {
// and therefore we also need this
if (end_sliced <= begin_sliced) {
// zero length slice
- return 0;
+ return {0, 0};
}
} else if ((opt.stop < 0) && (opt.stop > opt.start)) {
// stop is negative, but larger than start, so we count again from the
right
@@ -2501,12 +2497,30 @@ struct SliceBytesTransform : StringSliceTransformBase {
end_sliced = std::max(end + opt.stop, begin_sliced);
} else {
// zero length slice
- return 0;
+ return {0, 0};
}
}
+ return {begin_sliced, end_sliced};
+ }
+
+ int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
+ uint8_t* output) {
+ // Slice in forward order (step > 0)
+ if (!input_string_bytes) {
+ return 0;
+ }
+
+ const SliceOptions& opt = *this->options;
+ auto [begin_index, end_index] = SliceForwardRange(opt, input_string_bytes);
+ const uint8_t* begin_sliced = input + begin_index;
+ const uint8_t* end_sliced = input + end_index;
+
+ if (begin_sliced == end_sliced) {
+ return 0;
+ }
// Second, copy computed slice to output
- DCHECK(begin_sliced <= end_sliced);
+ DCHECK(begin_sliced < end_sliced);
if (opt.step == 1) {
// fast case, where we simply can finish with a memcpy
std::copy(begin_sliced, end_sliced, output);
@@ -2525,18 +2539,13 @@ struct SliceBytesTransform : StringSliceTransformBase {
return dest - output;
}
- int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
- uint8_t* output) {
+ static std::pair<int64_t, int64_t> SliceBackwardRange(const SliceOptions&
opt,
+ int64_t
input_string_bytes) {
// Slice in reverse order (step < 0)
- const SliceOptions& opt = *this->options;
- const uint8_t* begin = input;
- const uint8_t* end = input + input_string_bytes;
- const uint8_t* begin_sliced = begin;
- const uint8_t* end_sliced = end;
-
- if (!input_string_bytes) {
- return 0;
- }
+ int64_t begin = 0;
+ int64_t end = input_string_bytes;
+ int64_t begin_sliced = begin;
+ int64_t end_sliced = end;
if (opt.start >= 0) {
// +1 because begin_sliced acts as as the end of a reverse iterator
@@ -2555,6 +2564,28 @@ struct SliceBytesTransform : StringSliceTransformBase {
}
end_sliced--;
+ if (begin_sliced <= end_sliced) {
+ // zero length slice
+ return {0, 0};
+ }
+
+ return {begin_sliced, end_sliced};
+ }
+
+ int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
+ uint8_t* output) {
+ if (!input_string_bytes) {
+ return 0;
+ }
+
+ const SliceOptions& opt = *this->options;
+ auto [begin_index, end_index] = SliceBackwardRange(opt,
input_string_bytes);
+ const uint8_t* begin_sliced = input + begin_index;
+ const uint8_t* end_sliced = input + end_index;
+
+ if (begin_sliced == end_sliced) {
+ return 0;
+ }
// Copy computed slice to output
uint8_t* dest = output;
const uint8_t* i = begin_sliced;
@@ -2568,6 +2599,22 @@ struct SliceBytesTransform : StringSliceTransformBase {
return dest - output;
}
+
+ static Result<int32_t> FixedOutputSize(SliceOptions options, int32_t
input_width_32) {
+ auto step = options.step;
+ if (step == 0) {
+ return Status::Invalid("Slice step cannot be zero");
+ }
+ if (step > 0) {
+ // forward slice
+ auto [begin_index, end_index] = SliceForwardRange(options,
input_width_32);
+ return static_cast<int32_t>((end_index - begin_index + step - 1) / step);
+ } else {
+ // backward slice
+ auto [begin_index, end_index] = SliceBackwardRange(options,
input_width_32);
+ return static_cast<int32_t>((end_index - begin_index + step + 1) / step);
+ }
+ }
};
template <typename Type>
@@ -2594,6 +2641,12 @@ void AddAsciiStringSlice(FunctionRegistry* registry) {
DCHECK_OK(
func->AddKernel({ty}, ty, std::move(exec),
SliceBytesTransform::State::Init));
}
+ using TransformExec =
FixedSizeBinaryTransformExecWithState<SliceBytesTransform>;
+ ScalarKernel fsb_kernel({InputType(Type::FIXED_SIZE_BINARY)},
+ OutputType(TransformExec::OutputType),
TransformExec::Exec,
+ StringSliceTransformBase::State::Init);
+ fsb_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(fsb_kernel)));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h
b/cpp/src/arrow/compute/kernels/scalar_string_internal.h
index 7a5d5a7c86..6723d11c8d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_internal.h
+++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h
@@ -250,6 +250,8 @@ struct StringSliceTransformBase : public
StringTransformBase {
using State = OptionsWrapper<SliceOptions>;
const SliceOptions* options;
+ StringSliceTransformBase() = default;
+ explicit StringSliceTransformBase(const SliceOptions& options) :
options{&options} {}
Status PreExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out)
override {
options = &State::Get(ctx);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 5dec16d89e..d7e35d0733 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -33,10 +33,10 @@
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"
+#include "arrow/type_fwd.h"
#include "arrow/util/value_parsing.h"
-namespace arrow {
-namespace compute {
+namespace arrow::compute {
// interesting utf8 characters for testing (lower case / upper case):
// * ῦ / Υ͂ (3 to 4 code units) (Note, we don't support this yet, utf8proc
does not use
@@ -712,11 +712,140 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryLength) {
"[6, null, 6]");
}
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceEmpty) {
+ SliceOptions options{2, 4};
+ CheckScalarUnary("binary_slice", ArrayFromJSON(fixed_size_binary(0),
R"([""])"),
+ ArrayFromJSON(fixed_size_binary(0), R"([""])"), &options);
+
+ CheckScalarUnary("binary_slice",
+ ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"),
+ ArrayFromJSON(fixed_size_binary(0), R"(["", null, ""])"),
&options);
+
+ CheckUnary("binary_slice", R"([null, null])", fixed_size_binary(2),
R"([null, null])",
+ &options);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceBasic) {
+ SliceOptions options{2, 4};
+ CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])",
fixed_size_binary(2),
+ R"(["cd", null, "ob"])", &options);
+
+ SliceOptions options_edgecase_1{-3, 1};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0),
+ R"(["", ""])", &options_edgecase_1);
+
+ SliceOptions options_edgecase_2{-10, -3};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz", null])",
fixed_size_binary(3),
+ R"(["abc", "foo", null])", &options_edgecase_2);
+
+ auto input = ArrayFromJSON(this->type(), R"(["foobaz"])");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr("Function 'binary_slice' cannot be called without
options"),
+ CallFunction("binary_slice", {input}));
+
+ SliceOptions options_invalid{2, 4, 0};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("Slice step cannot be zero"),
+ CallFunction("binary_slice", {input}, &options_invalid));
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosPos) {
+ SliceOptions options_step{1, 5, 2};
+ CheckUnary("binary_slice", R"([null, "abcdef", "foobaz"])",
fixed_size_binary(2),
+ R"([null, "bd", "ob"])", &options_step);
+
+ SliceOptions options_step_neg{5, 0, -2};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+ R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySlicePosNeg) {
+ SliceOptions options{2, -1};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+ R"(["cde", "oba"])", &options);
+
+ SliceOptions options_step{1, -1, 2};
+ CheckUnary("binary_slice", R"(["abcdef", null, "foobaz"])",
fixed_size_binary(2),
+ R"(["bd", null, "ob"])", &options_step);
+
+ SliceOptions options_step_neg{5, -4, -2};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+ R"(["fd", "zb"])", &options_step_neg);
+
+ options_step_neg.stop = -6;
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+ R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegNeg) {
+ SliceOptions options{-2, -1};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(1),
+ R"(["e", "a"])", &options);
+
+ SliceOptions options_step{-4, -1, 2};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz", null, null])",
fixed_size_binary(2),
+ R"(["ce", "oa", null, null])", &options_step);
+
+ SliceOptions options_step_neg{-1, -3, -2};
+ CheckUnary("binary_slice", R"([null, "abcdef", null, "foobaz"])",
fixed_size_binary(1),
+ R"([null, "f", null, "z"])", &options_step_neg);
+
+ options_step_neg.stop = -4;
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+ R"(["fd", "zb"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceNegPos) {
+ SliceOptions options{-2, 4};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(0),
+ R"(["", ""])", &options);
+
+ SliceOptions options_step{-4, 5, 2};
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(2),
+ R"(["ce", "oa"])", &options_step);
+
+ SliceOptions options_step_neg{-1, 1, -2};
+ CheckUnary("binary_slice", R"([null, "abcdef", "foobaz", null])",
fixed_size_binary(2),
+ R"([null, "fd", "zb", null])", &options_step_neg);
+
+ options_step_neg.stop = 0;
+ CheckUnary("binary_slice", R"(["abcdef", "foobaz"])", fixed_size_binary(3),
+ R"(["fdb", "zbo"])", &options_step_neg);
+}
+
+TEST_F(TestFixedSizeBinaryKernels, BinarySliceConsistentyWithVarLenBinary) {
+ std::string source_str = "abcdef";
+ for (size_t str_len = 0; str_len < source_str.size(); ++str_len) {
+ auto input_str = source_str.substr(0, str_len);
+ auto fixed_input =
ArrayFromJSON(fixed_size_binary(static_cast<int32_t>(str_len)),
+ R"([")" + input_str + R"("])");
+ auto varlen_input = ArrayFromJSON(binary(), R"([")" + input_str + R"("])");
+ for (auto start = -6; start <= 6; ++start) {
+ for (auto stop = -6; stop <= 6; ++stop) {
+ for (auto step = -3; step <= 4; ++step) {
+ if (step == 0) {
+ continue;
+ }
+ SliceOptions options{start, stop, step};
+ auto expected =
+ CallFunction("binary_slice", {varlen_input},
&options).ValueOrDie();
+ auto actual =
+ CallFunction("binary_slice", {fixed_input},
&options).ValueOrDie();
+ actual = Cast(actual, binary()).ValueOrDie();
+ ASSERT_OK(actual.make_array()->ValidateFull());
+ AssertDatumsEqual(expected, actual);
+ }
+ }
+ }
+ }
+}
+
TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
ReplaceSliceOptions options{0, 1, "XX"};
CheckUnary("binary_replace_slice", "[]", fixed_size_binary(7), "[]",
&options);
- CheckUnary("binary_replace_slice", R"([null, "abcdef"])",
fixed_size_binary(7),
- R"([null, "XXbcdef"])", &options);
+ CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])",
+ fixed_size_binary(7), R"(["XXoobaz", null, "XXbcdef"])",
&options);
ReplaceSliceOptions options_shrink{0, 2, ""};
CheckUnary("binary_replace_slice", R"([null, "abcdef"])",
fixed_size_binary(4),
@@ -731,8 +860,8 @@ TEST_F(TestFixedSizeBinaryKernels, BinaryReplaceSlice) {
R"([null, "abXXef"])", &options_middle);
ReplaceSliceOptions options_neg_start{-3, -2, "XX"};
- CheckUnary("binary_replace_slice", R"([null, "abcdef"])",
fixed_size_binary(7),
- R"([null, "abcXXef"])", &options_neg_start);
+ CheckUnary("binary_replace_slice", R"(["foobaz", null, "abcdef"])",
+ fixed_size_binary(7), R"(["fooXXaz", null, "abcXXef"])",
&options_neg_start);
ReplaceSliceOptions options_neg_end{2, -2, "XX"};
CheckUnary("binary_replace_slice", R"([null, "abcdef"])",
fixed_size_binary(6),
@@ -807,7 +936,7 @@ TEST_F(TestFixedSizeBinaryKernels,
CountSubstringIgnoreCase) {
offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options);
MatchSubstringOptions options_empty{"", /*ignore_case=*/true};
- CheckUnary("count_substring", R"([" ", null, "abcABc"])", offset_type(),
+ CheckUnary("count_substring", R"([" ", null, "abcdef"])", offset_type(),
"[7, null, 7]", &options_empty);
}
@@ -2382,5 +2511,4 @@ TEST(TestStringKernels, UnicodeLibraryAssumptions) {
}
#endif
-} // namespace compute
-} // namespace arrow
+} // namespace arrow::compute
diff --git a/python/pyarrow/tests/test_compute.py
b/python/pyarrow/tests/test_compute.py
index 7c5a134d33..d1eb605c71 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -561,7 +561,8 @@ def test_slice_compatibility():
def test_binary_slice_compatibility():
- arr = pa.array([b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"])
+ data = [b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"]
+ arr = pa.array(data)
for start, stop, step in itertools.product(range(-6, 6),
range(-6, 6),
range(-3, 4)):
@@ -574,6 +575,13 @@ def test_binary_slice_compatibility():
assert expected.equals(result)
# Positional options
assert pc.binary_slice(arr, start, stop, step) == result
+ # Fixed size binary input / output
+ for item in data:
+ fsb_scalar = pa.scalar(item, type=pa.binary(len(item)))
+ expected = item[start:stop:step]
+ actual = pc.binary_slice(fsb_scalar, start, stop, step)
+ assert actual.type == pa.binary(len(expected))
+ assert actual.as_py() == expected
def test_split_pattern():