This is an automated email from the ASF dual-hosted git repository.

changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6671dd0b5 [GLUTEN-6989][CH] Support RTrim with const source column (#6992)
6671dd0b5 is described below

commit 6671dd0b5cebf2e4660b94dce4038005f259af7d
Author: Wenzheng Liu <[email protected]>
AuthorDate: Thu Aug 29 10:04:10 2024 +0800

    [GLUTEN-6989][CH] Support RTrim with const source column (#6992)
    
    * [GLUTEN-6989][CH] Support RTrim with const source column
    * [GLUTEN-6989][CH] Use bitmap256 to trim
    * [GLUTEN-6989][CH] Fix comments
    * [GLUTEN-6989][CH] Fix core dump
    * [GLUTEN-6989][CH] Remove condition when trim col contains null
---
 .../GlutenClickhouseStringFunctionsSuite.scala     |  38 +++++
 .../local-engine/Functions/SparkFunctionTrim.cpp   | 173 ++++++++++-----------
 2 files changed, 117 insertions(+), 94 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
index 98c0c2b35..e40b293ea 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
@@ -49,6 +49,44 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra
     }
   }
 
+  test("GLUTEN-6989: rtrim support source column const") {
+    withTable("trim") {
+      sql("create table trim(trim_col String, src_col String) using parquet")
+      sql("""
+            |insert into trim values
+            |  ('bAa', 'a'),('bba', 'b'),('abcdef', 'abcd'),
+            |  (null, '123'),('123', null), ('', 'aaa'), ('bbb', '')
+            |""".stripMargin)
+
+      val sql0 = "select rtrim('aba', 'a') from trim order by src_col"
+      val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col"
+      val sql2 = "select rtrim(trim_col, 'cCBbAa') from trim order by src_col"
+      val sql3 = "select rtrim(trim_col, '') from trim order by src_col"
+      val sql4 = "select rtrim('', 'AAA') from trim order by src_col"
+      val sql5 = "select rtrim('', src_col) from trim order by src_col"
+      val sql6 = "select rtrim('ab', src_col) from trim order by src_col"
+
+      runQueryAndCompare(sql0) { _ => }
+      runQueryAndCompare(sql1) { _ => }
+      runQueryAndCompare(sql2) { _ => }
+      runQueryAndCompare(sql3) { _ => }
+      runQueryAndCompare(sql4) { _ => }
+      runQueryAndCompare(sql5) { _ => }
+      runQueryAndCompare(sql6) { _ => }
+
+      // test other trim functions
+      val sql7 = "SELECT trim(LEADING trim_col FROM src_col) from trim"
+      val sql8 = "SELECT trim(LEADING trim_col FROM 'NSB') from trim"
+      val sql9 = "SELECT trim(TRAILING trim_col FROM src_col) from trim"
+      val sql10 = "SELECT trim(TRAILING trim_col FROM '') from trim"
+      runQueryAndCompare(sql7) { _ => }
+      runQueryAndCompare(sql8) { _ => }
+      runQueryAndCompare(sql9) { _ => }
+      runQueryAndCompare(sql10) { _ => }
+
+    }
+  }
+
   test("GLUTEN-5897: fix regexp_extract with bracket") {
     withTable("regexp_extract_bracket") {
       sql("create table regexp_extract_bracket(a String) using parquet")
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
index 88ed3f635..fbbd944c0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
@@ -106,125 +106,110 @@ namespace
         ColumnPtr
         executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr & /*result_type*/, size_t input_rows_count) const override
         {
-            const ColumnString * src_str_col = 
checkAndGetColumn<ColumnString>(arguments[0].column.get());
-            if (!src_str_col)
-                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of 
function {} must be String", getName());
+            const ColumnString * src_col = 
checkAndGetColumn<ColumnString>(arguments[0].column.get());
+            const ColumnConst * src_const_col = 
checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
+            const ColumnString * trim_col = 
checkAndGetColumn<ColumnString>(arguments[1].column.get());
+            const ColumnConst * trim_const_col = 
checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
+
+            String src_const_str;
+            String trim_const_str;
+            if (src_const_col)
+                src_const_str = src_const_col->getValue<String>();
+            if (trim_const_col)
+            {
+                trim_const_str = trim_const_col->getValue<String>();
+                if (trim_const_str.empty())
+                {
+                    return arguments[0].column;
+                }
+            }
 
+            // If both arguments are constants, it will be simplified to a 
constant. Skipped here.
 
-            if (const auto * trim_const_str_col = 
checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
-            {
-                String trim_str = trim_const_str_col->getValue<String>();
-                if (trim_str.empty())
-                    return src_str_col->cloneResized(input_rows_count);
+            auto res_col = ColumnString::create();
+            ColumnString::Chars & res_data = res_col->getChars();
+            ColumnString::Offsets & res_offsets = res_col->getOffsets();
+            res_offsets.resize_exact(input_rows_count);
 
-                auto res_col = ColumnString::create();
-                res_col->reserve(input_rows_count);
-                executeVector(src_str_col->getChars(), 
src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), 
trim_str);
+            // Source column is constant and trim column is not constant
+            if (src_const_col)
+            {
+                res_data.reserve_exact(src_const_str.size() * 
input_rows_count);
+                for (size_t row = 0; row < input_rows_count; ++row)
+                {
+                    StringRef trim_str_ref = trim_col->getDataAt(row);
+                    std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+                    executeRow(src_const_str.c_str(), src_const_str.size(), 
res_data, res_offsets, row, *trim_set);
+                }
                 return std::move(res_col);
             }
-            else if (const auto * trim_str_col = 
checkAndGetColumn<ColumnString>(arguments[1].column.get()))
+
+            // Source column is not constant and trim column is constant
+            if (trim_const_col)
             {
-                auto res_col = ColumnString::create();
-                res_col->reserve(input_rows_count);
-
-                executeVector(
-                    src_str_col->getChars(),
-                    src_str_col->getOffsets(),
-                    res_col->getChars(),
-                    res_col->getOffsets(),
-                    trim_str_col->getChars(),
-                    trim_str_col->getOffsets());
+                res_data.reserve_exact(src_col->getChars().size());
+                std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_const_str.c_str(), trim_const_str.size());
+                for (size_t row = 0; row < input_rows_count; ++row)
+                {
+                    StringRef src_str_ref = src_col->getDataAt(row);
+                    executeRow(src_str_ref.data, src_str_ref.size, res_data, 
res_offsets, row, *trim_set);
+                }
                 return std::move(res_col);
             }
 
-            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of 
function {} must be String or Const String", getName());
+            // Both columns are not constant
+            res_data.reserve(src_col->getChars().size());
+            for (size_t row = 0; row < input_rows_count; ++row)
+            {
+                StringRef src_str_ref = src_col->getDataAt(row);
+                StringRef trim_str_ref = trim_col->getDataAt(row);
+                std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+                executeRow(src_str_ref.data, src_str_ref.size, res_data, 
res_offsets, row, *trim_set);
+            }
+            return std::move(res_col);
         }
 
     private:
-        void executeVector(
-            const ColumnString::Chars & data,
-            const ColumnString::Offsets & offsets,
+        void executeRow(
+            const char * src,
+            size_t src_size,
             ColumnString::Chars & res_data,
             ColumnString::Offsets & res_offsets,
-            const String & trim_str) const
+            size_t row,
+            const std::bitset<256> & trim_set) const
         {
-            res_data.reserve_exact(data.size());
-
-            size_t rows = offsets.size();
-            res_offsets.resize_exact(rows);
-
-            size_t prev_offset = 0;
-            size_t res_offset = 0;
-
-            const UInt8 * start;
-            size_t length;
-            std::unordered_set<char> trim_set(trim_str.begin(), 
trim_str.end());
-            for (size_t i = 0; i < rows; ++i)
-            {
-                trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), 
offsets[i] - prev_offset - 1, start, length, trim_set);
-                res_data.resize_exact(res_data.size() + length + 1);
-                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], 
start, length);
-                res_offset += length + 1;
-                res_data[res_offset - 1] = '\0';
-
-                res_offsets[i] = res_offset;
-                prev_offset = offsets[i];
-            }
+            const char * dst;
+            size_t dst_size;
+            trim(src, src_size, dst, dst_size, trim_set);
+            size_t res_offset = row > 0 ? res_offsets[row - 1] : 0;
+            res_data.resize_exact(res_data.size() + dst_size + 1);
+            memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst, 
dst_size);
+            res_offset += dst_size + 1;
+            res_data[res_offset - 1] = '\0';
+            res_offsets[row] = res_offset;
         }
 
-        void executeVector(
-            const ColumnString::Chars & data,
-            const ColumnString::Offsets & offsets,
-            ColumnString::Chars & res_data,
-            ColumnString::Offsets & res_offsets,
-            const ColumnString::Chars & trim_data,
-            const ColumnString::Offsets & trim_offsets) const
+        std::unique_ptr<std::bitset<256>> buildTrimSet(const char* data, const 
size_t size) const
         {
-            res_data.reserve_exact(data.size());
-
-            size_t rows = offsets.size();
-            res_offsets.resize_exact(rows);
-
-            size_t prev_offset = 0;
-            size_t prev_trim_str_offset = 0;
-            size_t res_offset = 0;
-
-            const UInt8 * start;
-            size_t length;
-
-            for (size_t i = 0; i < rows; ++i)
-            {
-                std::unordered_set<char> trim_set(
-                    &trim_data[prev_trim_str_offset], 
&trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1);
-
-                trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), 
offsets[i] - prev_offset - 1, start, length, trim_set);
-                res_data.resize_exact(res_data.size() + length + 1);
-                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], 
start, length);
-                res_offset += length + 1;
-                res_data[res_offset - 1] = '\0';
-
-                res_offsets[i] = res_offset;
-                prev_offset = offsets[i];
-                prev_trim_str_offset = trim_offsets[i];
-            }
+            auto trim_set = std::make_unique<std::bitset<256>>();
+            for (size_t i = 0; i < size; ++i)
+                trim_set->set((unsigned char)data[i]);
+            return trim_set;
         }
 
-        void
-        trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t 
& res_size, const std::unordered_set<char> & trim_set) const
+        void trim(const char * src, const size_t src_size, const char *& dst, 
size_t & dst_size, const std::bitset<256> & trim_set) const
         {
-            const char * char_data = reinterpret_cast<const char *>(data);
-            const char * char_end = char_data + size;
-
+            const char * src_end = src + src_size;
             if constexpr (TrimMode::trim_left)
-                while (char_data < char_end && trim_set.contains(*char_data))
-                    ++char_data;
+                while (src < src_end && trim_set.test((unsigned char)*src))
+                    ++src;
 
             if constexpr (TrimMode::trim_right)
-                while (char_data < char_end && trim_set.contains(*(char_end - 
1)))
-                    --char_end;
+                while (src < src_end && trim_set.test((unsigned char)*(src_end 
- 1)))
+                    --src_end;
 
-            res_data = reinterpret_cast<const UInt8 *>(char_data);
-            res_size = char_end - char_data;
+            dst = src;
+            dst_size = src_end - src;
         }
     };
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to