This is an automated email from the ASF dual-hosted git repository.
chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fury.git
The following commit(s) were added to refs/heads/main by this push:
new b32f3f9a feat(C++): String detection is performed using SIMD
techniques (#1720)
b32f3f9a is described below
commit b32f3f9acca14dd474d9bcf59786a8bffd02a0a4
Author: PAN <[email protected]>
AuthorDate: Sun Jul 14 15:11:29 2024 +0800
feat(C++): String detection is performed using SIMD techniques (#1720)
## What does this PR do?
ref: https://arxiv.org/pdf/1902.08318.pdf
ref: https://github.com/simdutf/simdutf
I studied the relevant SIMD techniques, together with the paper and the
simdutf project referenced above.
This PR uses SIMD techniques for string detection.
First, I implemented the baseline logic for Latin character
detection:
``` c++
// Baseline implementation: a scalar scan that inspects one byte at a time.
// Returns true only if every byte of `str` is below 128 (an empty string
// therefore yields true).
bool isLatin_Baseline(const std::string& str) {
  const std::string::size_type total = str.size();
  for (std::string::size_type pos = 0; pos < total; ++pos) {
    // Compare as unsigned so bytes with the high bit set are not seen as
    // negative values.
    const unsigned char byte = static_cast<unsigned char>(str[pos]);
    if (byte > 127) {
      return false;
    }
  }
  return true;
}
```
<img width="393" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_base_line_test1.png">
Then I tried SSE2 to speed it up, which is a little bit faster. The
idea is to read multiple characters at once and then apply bitwise
operations to them.
There was a clear speed boost, but I didn't think it was enough, so
I tried again with AVX2.
<img width="493" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_test_all_1.png">
In terms of efficiency, it is already much faster than before.
But how can we show that it is also logically correct?
I added test cases to verify this:
``` C++
TEST(StringUtilTest, TestIsLatinLogic)
```
Finally, I ran the test
<img width="493" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_ubantu_test_1.png">
done.
<!-- Describe the purpose of this PR. -->
## Related issues
Closes #313
<!--
Is there any related issue? Please attach here.
- #xxxx0
- #xxxx1
- #xxxx2
-->
## Does this PR introduce any user-facing change?
<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fury/issues/new/choose) describing the
need to do so and update the document if necessary.
-->
- [x] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?
## Benchmark
<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it will have impact on performance, the code reviewer will explain
it), be sure to attach a benchmark data here.
-->
---
cpp/fury/util/BUILD | 11 ++++
cpp/fury/util/string_util.cc | 121 ++++++++++++++++++++++++++++++++++++++
cpp/fury/util/string_util.h | 28 +++++++++
cpp/fury/util/string_util_test.cc | 106 +++++++++++++++++++++++++++++++++
4 files changed, 266 insertions(+)
diff --git a/cpp/fury/util/BUILD b/cpp/fury/util/BUILD
index 1aa1b87c..8f605dc7 100644
--- a/cpp/fury/util/BUILD
+++ b/cpp/fury/util/BUILD
@@ -4,6 +4,8 @@ cc_library(
name = "fury_util",
srcs = glob(["*.cc"], exclude=["*test.cc"]),
hdrs = glob(["*.h"]),
+ copts = ["-mavx2"], # Enable AVX2 support
+ linkopts = ["-mavx2"], # Ensure linker also knows about AVX2
strip_include_prefix = "/cpp",
alwayslink=True,
linkstatic=True,
@@ -52,3 +54,12 @@ cc_test(
"@com_google_googletest//:gtest",
],
)
+
+cc_test( # Test target for the isLatin string utility.
+ name = "string_util_test",
+ srcs = ["string_util_test.cc"],
+ deps = [
+ ":fury_util", # library under test; its srcs glob picks up string_util.cc
+ "@com_google_googletest//:gtest", # plain gtest target; the test source defines its own main()
+ ],
+)
\ No newline at end of file
diff --git a/cpp/fury/util/string_util.cc b/cpp/fury/util/string_util.cc
new file mode 100644
index 00000000..1f57b76f
--- /dev/null
+++ b/cpp/fury/util/string_util.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "string_util.h"
+
+#if defined(__x86_64__) || defined(_M_X64)
+#include <immintrin.h>
+#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
+#include <arm_neon.h>
+#elif defined(__riscv) && __riscv_vector
+#include <riscv_vector.h>
+#endif
+
+namespace fury {
+
+#if defined(__x86_64__) || defined(_M_X64)
+
+bool isLatin(const std::string &str) { // x86-64 variant: returns true iff no byte of str has its high bit (0x80) set; an empty string yields true
+ const char *data = str.data();
+ size_t len = str.size();
+
+ size_t i = 0; // NOTE(review): this branch is selected by __x86_64__/_M_X64 alone, not __AVX2__; it relies on the -mavx2 build flag and presumably on the runtime CPU supporting AVX2 - confirm
+ __m256i latin_mask = _mm256_set1_epi8(0x80); // 0x80 replicated into all 32 byte lanes
+ for (; i + 32 <= len; i += 32) { // consume full 32-byte chunks only; the remainder is handled by the scalar loop below
+ __m256i chars =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i *>(data + i)); // unaligned load, so data + i needs no particular alignment
+ __m256i result = _mm256_and_si256(chars, latin_mask); // keep only the high bit of each byte
+ if (!_mm256_testz_si256(result, result)) { // testz returns 1 only when result is all zero, so this fires when some byte is >= 0x80
+ return false;
+ }
+ }
+
+ for (; i < len; ++i) { // scalar tail: the final len % 32 bytes
+ if (static_cast<unsigned char>(data[i]) >= 128) { // unsigned compare so high-bit bytes are not treated as negative
+ return false;
+ }
+ }
+
+ return true;
+}
+
+#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
+
+bool isLatin(const std::string &str) { // ARM NEON variant: returns true iff no byte of str has its high bit (0x80) set; an empty string yields true
+ const char *data = str.data();
+ size_t len = str.size();
+
+ size_t i = 0;
+ uint8x16_t latin_mask = vdupq_n_u8(0x80); // 0x80 replicated into all 16 byte lanes
+ for (; i + 16 <= len; i += 16) { // consume full 16-byte chunks only; the remainder is handled by the scalar loop below
+ uint8x16_t chars = vld1q_u8(reinterpret_cast<const uint8_t *>(data + i));
+ uint8x16_t result = vandq_u8(chars, latin_mask); // keep only the high bit of each byte
+ if (vmaxvq_u8(result) != 0) { // horizontal max is non-zero when any lane kept its high bit; NOTE(review): vmaxvq_u8 appears to be AArch64-only - confirm 32-bit ARM NEON builds are not expected
+ return false;
+ }
+ }
+
+ for (; i < len; ++i) { // scalar tail: the final len % 16 bytes
+ if (static_cast<unsigned char>(data[i]) >= 128) { // unsigned compare so high-bit bytes are not treated as negative
+ return false;
+ }
+ }
+
+ return true;
+}
+
+#elif defined(__riscv) && __riscv_vector
+
+bool isLatin(const std::string &str) { // RISC-V vector variant: intended to return true iff no byte of str has its high bit (0x80) set
+ const char *data = str.data();
+ size_t len = str.size();
+
+ size_t i = 0;
+ for (; i + 16 <= len; i += 16) { // advances 16 bytes per iteration; NOTE(review): assumes vl=16 is always granted for u8m1 (VLEN >= 128) - confirm
+ auto chars = vle8_v_u8m1(reinterpret_cast<const uint8_t *>(data + i), 16);
+ auto mask = vmv_v_x_u8m1(0x80, 16); // loop-invariant splat of 0x80, rebuilt on every iteration
+ auto result = vand_vv_u8m1(chars, mask, 16); // keep only the high bit of each byte
+ if (vmax_v_u8m1(result, 16) != 0) { // presumably meant as a horizontal max; NOTE(review): verify this intrinsic exists in the targeted RVV intrinsics version (reductions are usually spelled vredmaxu)
+ return false;
+ }
+ }
+
+ for (; i < len; ++i) { // scalar tail: the final len % 16 bytes
+ if (static_cast<unsigned char>(data[i]) >= 128) { // unsigned compare so high-bit bytes are not treated as negative
+ return false;
+ }
+ }
+
+ return true;
+}
+
+#else
+
+bool isLatin(const std::string &str) { // portable fallback used when none of the SIMD branches above is selected
+ for (char c : str) {
+ if (static_cast<unsigned char>(c) >= 128) { // unsigned compare so high-bit bytes are not treated as negative
+ return false;
+ }
+ }
+ return true; // also reached for an empty string
+}
+
+#endif
+
+} // namespace fury
diff --git a/cpp/fury/util/string_util.h b/cpp/fury/util/string_util.h
new file mode 100644
index 00000000..0824d1a2
--- /dev/null
+++ b/cpp/fury/util/string_util.h
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <string>
+
+namespace fury {
+
+bool isLatin(const std::string &str); // Returns false if any byte of str is >= 0x80 and true otherwise (including for an empty string), per the implementations in string_util.cc.
+
+} // namespace fury
diff --git a/cpp/fury/util/string_util_test.cc
b/cpp/fury/util/string_util_test.cc
new file mode 100644
index 00000000..045454db
--- /dev/null
+++ b/cpp/fury/util/string_util_test.cc
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <chrono>
+#include <iostream>
+#include <random>
+
+#include "fury/util/logging.h"
+#include "string_util.h"
+#include "gtest/gtest.h"
+
+namespace fury {
+
+// Function to generate a random string
+std::string generateRandomString(size_t length) { // builds a string of `length` characters drawn uniformly from charset (ASCII letters and digits)
+ const char charset[] =
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
+ std::default_random_engine rng(std::random_device{}()); // seeded from std::random_device, so the output is generally not reproducible across runs
+ std::uniform_int_distribution<> dist(0, sizeof(charset) - 2); // sizeof counts the terminating NUL, so the last usable index is sizeof - 2
+
+ std::string result;
+ result.reserve(length); // single allocation up front
+ for (size_t i = 0; i < length; ++i) {
+ result += charset[dist(rng)];
+ }
+
+ return result;
+}
+
+bool isLatin_BaseLine(const std::string &str) { // scalar reference used for the timing comparison: true iff every byte is < 128
+ for (char c : str) {
+ if (static_cast<unsigned char>(c) >= 128) { // unsigned compare so high-bit bytes are not treated as negative
+ return false;
+ }
+ }
+ return true;
+}
+
+TEST(StringUtilTest, TestIsLatinFunctions) { // times one call of the baseline and one of isLatin on the same 100000-character string; timings are logged, not asserted
+ std::string testStr = generateRandomString(100000);
+ auto start_time = std::chrono::high_resolution_clock::now();
+ bool result = isLatin_BaseLine(testStr); // NOTE(review): this value is overwritten below before it is ever checked
+ auto end_time = std::chrono::high_resolution_clock::now();
+ auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
+ end_time - start_time)
+ .count();
+ FURY_LOG(INFO) << "BaseLine Running Time: " << duration << " ns.";
+
+ start_time = std::chrono::high_resolution_clock::now();
+ result = isLatin(testStr);
+ end_time = std::chrono::high_resolution_clock::now();
+ duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end_time -
+ start_time)
+ .count();
+ FURY_LOG(INFO) << "Optimized Running Time: " << duration << " ns.";
+
+ EXPECT_TRUE(result); // only the isLatin result is asserted; the input holds just letters and digits
+}
+
+TEST(StringUtilTest, TestIsLatinLogic) { // checks isLatin on inputs of various lengths, with and without non-ASCII bytes
+ // Test strings with only Latin characters
+ EXPECT_TRUE(isLatin("Fury"));
+ EXPECT_TRUE(isLatin(generateRandomString(80))); // 80 = 2 * 32 + 16: two full 32-byte chunks plus a 16-byte remainder
+
+ // Test unaligned strings with only Latin characters
+ EXPECT_TRUE(isLatin(generateRandomString(80) + "1")); // lengths 81-83 are not a multiple of 16 or 32
+ EXPECT_TRUE(isLatin(generateRandomString(80) + "12"));
+ EXPECT_TRUE(isLatin(generateRandomString(80) + "123"));
+
+ // Test strings with non-Latin characters
+ EXPECT_FALSE(isLatin("你好, Fury"));
+ EXPECT_FALSE(isLatin(generateRandomString(80) + "你好")); // non-ASCII bytes placed after an 80-character ASCII prefix
+ EXPECT_FALSE(isLatin(generateRandomString(80) + "1你好"));
+ EXPECT_FALSE(isLatin(generateRandomString(11) + "你"));
+ EXPECT_FALSE(isLatin(generateRandomString(10) + "你好"));
+ EXPECT_FALSE(isLatin(generateRandomString(9) + "性能好"));
+ EXPECT_FALSE(isLatin("\u1234")); // universal character name; presumably encoded as multi-byte UTF-8 by the compiler - confirm execution charset
+ EXPECT_FALSE(isLatin("a\u1234"));
+ EXPECT_FALSE(isLatin("ab\u1234"));
+ EXPECT_FALSE(isLatin("abc\u1234"));
+ EXPECT_FALSE(isLatin("abcd\u1234"));
+ EXPECT_FALSE(isLatin("Javaone Keynote\u1234"));
+}
+
+} // namespace fury
+
+int main(int argc, char **argv) { // explicit test entry point for this file
+ ::testing::InitGoogleTest(&argc, argv); // passes the command-line arguments to GoogleTest
+ return RUN_ALL_TESTS(); // the result becomes the process exit code
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]