This is an automated email from the ASF dual-hosted git repository.
changchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 6671dd0b5 [GLUTEN-6989][CH] Support RTrim with const source column (#6992)
6671dd0b5 is described below
commit 6671dd0b5cebf2e4660b94dce4038005f259af7d
Author: Wenzheng Liu <[email protected]>
AuthorDate: Thu Aug 29 10:04:10 2024 +0800
[GLUTEN-6989][CH] Support RTrim with const source column (#6992)
* [GLUTEN-6989][CH] Support RTrim with const source column
* [GLUTEN-6989][CH] Use bitmap256 to trim
* [GLUTEN-6989][CH] Fix comments
* [GLUTEN-6989][CH] Fix core dump
* [GLUTEN-6989][CH] Remove condition when trim col contains null
---
.../GlutenClickhouseStringFunctionsSuite.scala | 38 +++++
.../local-engine/Functions/SparkFunctionTrim.cpp | 173 ++++++++++-----------
2 files changed, 117 insertions(+), 94 deletions(-)
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
index 98c0c2b35..e40b293ea 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseStringFunctionsSuite.scala
@@ -49,6 +49,44 @@ class GlutenClickhouseStringFunctionsSuite extends GlutenClickHouseWholeStageTra
}
}
+ test("GLUTEN-6989: rtrim support source column const") {
+ withTable("trim") {
+ sql("create table trim(trim_col String, src_col String) using parquet")
+ sql("""
+ |insert into trim values
+ | ('bAa', 'a'),('bba', 'b'),('abcdef', 'abcd'),
+ | (null, '123'),('123', null), ('', 'aaa'), ('bbb', '')
+ |""".stripMargin)
+
+ val sql0 = "select rtrim('aba', 'a') from trim order by src_col"
+ val sql1 = "select rtrim(trim_col, src_col) from trim order by src_col"
+ val sql2 = "select rtrim(trim_col, 'cCBbAa') from trim order by src_col"
+ val sql3 = "select rtrim(trim_col, '') from trim order by src_col"
+ val sql4 = "select rtrim('', 'AAA') from trim order by src_col"
+ val sql5 = "select rtrim('', src_col) from trim order by src_col"
+ val sql6 = "select rtrim('ab', src_col) from trim order by src_col"
+
+ runQueryAndCompare(sql0) { _ => }
+ runQueryAndCompare(sql1) { _ => }
+ runQueryAndCompare(sql2) { _ => }
+ runQueryAndCompare(sql3) { _ => }
+ runQueryAndCompare(sql4) { _ => }
+ runQueryAndCompare(sql5) { _ => }
+ runQueryAndCompare(sql6) { _ => }
+
+ // test other trim functions
+ val sql7 = "SELECT trim(LEADING trim_col FROM src_col) from trim"
+ val sql8 = "SELECT trim(LEADING trim_col FROM 'NSB') from trim"
+ val sql9 = "SELECT trim(TRAILING trim_col FROM src_col) from trim"
+ val sql10 = "SELECT trim(TRAILING trim_col FROM '') from trim"
+ runQueryAndCompare(sql7) { _ => }
+ runQueryAndCompare(sql8) { _ => }
+ runQueryAndCompare(sql9) { _ => }
+ runQueryAndCompare(sql10) { _ => }
+
+ }
+ }
+
test("GLUTEN-5897: fix regexp_extract with bracket") {
withTable("regexp_extract_bracket") {
sql("create table regexp_extract_bracket(a String) using parquet")
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
index 88ed3f635..fbbd944c0 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp
@@ -106,125 +106,110 @@ namespace
ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr & /*result_type*/, size_t input_rows_count) const override
{
- const ColumnString * src_str_col =
checkAndGetColumn<ColumnString>(arguments[0].column.get());
- if (!src_str_col)
- throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of
function {} must be String", getName());
+ const ColumnString * src_col =
checkAndGetColumn<ColumnString>(arguments[0].column.get());
+ const ColumnConst * src_const_col =
checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
+ const ColumnString * trim_col =
checkAndGetColumn<ColumnString>(arguments[1].column.get());
+ const ColumnConst * trim_const_col =
checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
+
+ String src_const_str;
+ String trim_const_str;
+ if (src_const_col)
+ src_const_str = src_const_col->getValue<String>();
+ if (trim_const_col)
+ {
+ trim_const_str = trim_const_col->getValue<String>();
+ if (trim_const_str.empty())
+ {
+ return arguments[0].column;
+ }
+ }
+ // If both arguments are constants, it will be simplified to a
constant. Skipped here.
- if (const auto * trim_const_str_col =
checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
- {
- String trim_str = trim_const_str_col->getValue<String>();
- if (trim_str.empty())
- return src_str_col->cloneResized(input_rows_count);
+ auto res_col = ColumnString::create();
+ ColumnString::Chars & res_data = res_col->getChars();
+ ColumnString::Offsets & res_offsets = res_col->getOffsets();
+ res_offsets.resize_exact(input_rows_count);
- auto res_col = ColumnString::create();
- res_col->reserve(input_rows_count);
- executeVector(src_str_col->getChars(),
src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(),
trim_str);
+ // Source column is constant and trim column is not constant
+ if (src_const_col)
+ {
+ res_data.reserve_exact(src_const_str.size() *
input_rows_count);
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef trim_str_ref = trim_col->getDataAt(row);
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+ executeRow(src_const_str.c_str(), src_const_str.size(),
res_data, res_offsets, row, *trim_set);
+ }
return std::move(res_col);
}
- else if (const auto * trim_str_col =
checkAndGetColumn<ColumnString>(arguments[1].column.get()))
+
+ // Source column is not constant and trim column is constant
+ if (trim_const_col)
{
- auto res_col = ColumnString::create();
- res_col->reserve(input_rows_count);
-
- executeVector(
- src_str_col->getChars(),
- src_str_col->getOffsets(),
- res_col->getChars(),
- res_col->getOffsets(),
- trim_str_col->getChars(),
- trim_str_col->getOffsets());
+ res_data.reserve_exact(src_col->getChars().size());
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_const_str.c_str(), trim_const_str.size());
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef src_str_ref = src_col->getDataAt(row);
+ executeRow(src_str_ref.data, src_str_ref.size, res_data,
res_offsets, row, *trim_set);
+ }
return std::move(res_col);
}
- throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of
function {} must be String or Const String", getName());
+ // Both columns are not constant
+ res_data.reserve(src_col->getChars().size());
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef src_str_ref = src_col->getDataAt(row);
+ StringRef trim_str_ref = trim_col->getDataAt(row);
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+ executeRow(src_str_ref.data, src_str_ref.size, res_data,
res_offsets, row, *trim_set);
+ }
+ return std::move(res_col);
}
private:
- void executeVector(
- const ColumnString::Chars & data,
- const ColumnString::Offsets & offsets,
+ void executeRow(
+ const char * src,
+ size_t src_size,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
- const String & trim_str) const
+ size_t row,
+ const std::bitset<256> & trim_set) const
{
- res_data.reserve_exact(data.size());
-
- size_t rows = offsets.size();
- res_offsets.resize_exact(rows);
-
- size_t prev_offset = 0;
- size_t res_offset = 0;
-
- const UInt8 * start;
- size_t length;
- std::unordered_set<char> trim_set(trim_str.begin(),
trim_str.end());
- for (size_t i = 0; i < rows; ++i)
- {
- trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]),
offsets[i] - prev_offset - 1, start, length, trim_set);
- res_data.resize_exact(res_data.size() + length + 1);
- memcpySmallAllowReadWriteOverflow15(&res_data[res_offset],
start, length);
- res_offset += length + 1;
- res_data[res_offset - 1] = '\0';
-
- res_offsets[i] = res_offset;
- prev_offset = offsets[i];
- }
+ const char * dst;
+ size_t dst_size;
+ trim(src, src_size, dst, dst_size, trim_set);
+ size_t res_offset = row > 0 ? res_offsets[row - 1] : 0;
+ res_data.resize_exact(res_data.size() + dst_size + 1);
+ memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst,
dst_size);
+ res_offset += dst_size + 1;
+ res_data[res_offset - 1] = '\0';
+ res_offsets[row] = res_offset;
}
- void executeVector(
- const ColumnString::Chars & data,
- const ColumnString::Offsets & offsets,
- ColumnString::Chars & res_data,
- ColumnString::Offsets & res_offsets,
- const ColumnString::Chars & trim_data,
- const ColumnString::Offsets & trim_offsets) const
+ std::unique_ptr<std::bitset<256>> buildTrimSet(const char* data, const
size_t size) const
{
- res_data.reserve_exact(data.size());
-
- size_t rows = offsets.size();
- res_offsets.resize_exact(rows);
-
- size_t prev_offset = 0;
- size_t prev_trim_str_offset = 0;
- size_t res_offset = 0;
-
- const UInt8 * start;
- size_t length;
-
- for (size_t i = 0; i < rows; ++i)
- {
- std::unordered_set<char> trim_set(
- &trim_data[prev_trim_str_offset],
&trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1);
-
- trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]),
offsets[i] - prev_offset - 1, start, length, trim_set);
- res_data.resize_exact(res_data.size() + length + 1);
- memcpySmallAllowReadWriteOverflow15(&res_data[res_offset],
start, length);
- res_offset += length + 1;
- res_data[res_offset - 1] = '\0';
-
- res_offsets[i] = res_offset;
- prev_offset = offsets[i];
- prev_trim_str_offset = trim_offsets[i];
- }
+ auto trim_set = std::make_unique<std::bitset<256>>();
+ for (size_t i = 0; i < size; ++i)
+ trim_set->set((unsigned char)data[i]);
+ return trim_set;
}
- void
- trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t
& res_size, const std::unordered_set<char> & trim_set) const
+ void trim(const char * src, const size_t src_size, const char *& dst,
size_t & dst_size, const std::bitset<256> & trim_set) const
{
- const char * char_data = reinterpret_cast<const char *>(data);
- const char * char_end = char_data + size;
-
+ const char * src_end = src + src_size;
if constexpr (TrimMode::trim_left)
- while (char_data < char_end && trim_set.contains(*char_data))
- ++char_data;
+ while (src < src_end && trim_set.test((unsigned char)*src))
+ ++src;
if constexpr (TrimMode::trim_right)
- while (char_data < char_end && trim_set.contains(*(char_end -
1)))
- --char_end;
+ while (src < src_end && trim_set.test((unsigned char)*(src_end
- 1)))
+ --src_end;
- res_data = reinterpret_cast<const UInt8 *>(char_data);
- res_size = char_end - char_data;
+ dst = src;
+ dst_size = src_end - src;
}
};
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]