taiyang-li commented on code in PR #6992:
URL: https://github.com/apache/incubator-gluten/pull/6992#discussion_r1732077370
##########
cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp:
##########
@@ -106,125 +106,118 @@ namespace
ColumnPtr
executeImpl(const ColumnsWithTypeAndName & arguments, const
DataTypePtr & /*result_type*/, size_t input_rows_count) const override
{
- const ColumnString * src_str_col =
checkAndGetColumn<ColumnString>(arguments[0].column.get());
- if (!src_str_col)
- throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of
function {} must be String", getName());
+ const ColumnString * src_col =
checkAndGetColumn<ColumnString>(arguments[0].column.get());
+ const ColumnConst * src_const_col =
checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
+ const ColumnString * trim_col =
checkAndGetColumn<ColumnString>(arguments[1].column.get());
+ const ColumnConst * trim_const_col =
checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
+
+ String src_const_str;
+ String trim_const_str;
+ if (src_const_col)
+ src_const_str = src_const_col->getValue<String>();
+ if (trim_const_col)
+ {
+ trim_const_str = trim_const_col->getValue<String>();
+ if (trim_const_str.empty())
+ {
+ return arguments[0].column;
+ }
+ }
+ // If both arguments are constants, it will be simplified to a
constant. Skipped here.
- if (const auto * trim_const_str_col =
checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
- {
- String trim_str = trim_const_str_col->getValue<String>();
- if (trim_str.empty())
- return src_str_col->cloneResized(input_rows_count);
+ auto res_col = ColumnString::create();
+ ColumnString::Chars & res_data = res_col->getChars();
+ ColumnString::Offsets & res_offsets = res_col->getOffsets();
+ res_offsets.resize_exact(input_rows_count);
- auto res_col = ColumnString::create();
- res_col->reserve(input_rows_count);
- executeVector(src_str_col->getChars(),
src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(),
trim_str);
+ // Source column is constant and trim column is not constant
+ if (src_const_col)
+ {
+ res_data.reserve_exact(src_const_str.size() *
input_rows_count);
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef trim_str_ref = trim_col->getDataAt(row);
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+ executeRow(src_const_str.c_str(), src_const_str.size(),
res_data, res_offsets, row, *trim_set);
+ }
return std::move(res_col);
}
- else if (const auto * trim_str_col =
checkAndGetColumn<ColumnString>(arguments[1].column.get()))
+
+ // Source column is not constant and trim column is constant
+ if (trim_const_col)
{
- auto res_col = ColumnString::create();
- res_col->reserve(input_rows_count);
-
- executeVector(
- src_str_col->getChars(),
- src_str_col->getOffsets(),
- res_col->getChars(),
- res_col->getOffsets(),
- trim_str_col->getChars(),
- trim_str_col->getOffsets());
+ res_data.reserve_exact(src_col->getChars().size());
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_const_str.c_str(), trim_const_str.size());
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef src_str_ref = src_col->getDataAt(row);
+ executeRow(src_str_ref.data, src_str_ref.size, res_data,
res_offsets, row, *trim_set);
+ }
return std::move(res_col);
}
- throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of
function {} must be String or Const String", getName());
+ // Both columns are not constant
+ res_data.reserve(src_col->getChars().size());
+ for (size_t row = 0; row < input_rows_count; ++row)
+ {
+ StringRef src_str_ref = src_col->getDataAt(row);
+ StringRef trim_str_ref = trim_col->getDataAt(row);
+ std::unique_ptr<std::bitset<256>> trim_set =
buildTrimSet(trim_str_ref.data, trim_str_ref.size);
+ executeRow(src_str_ref.data, src_str_ref.size, res_data,
res_offsets, row, *trim_set);
+ }
+ return std::move(res_col);
}
private:
- void executeVector(
- const ColumnString::Chars & data,
- const ColumnString::Offsets & offsets,
+ void executeRow(
+ const char * src,
+ size_t src_size,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
- const String & trim_str) const
+ size_t row,
+ const std::bitset<256> & trim_set) const
{
- res_data.reserve_exact(data.size());
-
- size_t rows = offsets.size();
- res_offsets.resize_exact(rows);
-
- size_t prev_offset = 0;
- size_t res_offset = 0;
-
- const UInt8 * start;
- size_t length;
- std::unordered_set<char> trim_set(trim_str.begin(),
trim_str.end());
- for (size_t i = 0; i < rows; ++i)
- {
- trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]),
offsets[i] - prev_offset - 1, start, length, trim_set);
- res_data.resize_exact(res_data.size() + length + 1);
- memcpySmallAllowReadWriteOverflow15(&res_data[res_offset],
start, length);
- res_offset += length + 1;
- res_data[res_offset - 1] = '\0';
-
- res_offsets[i] = res_offset;
- prev_offset = offsets[i];
- }
+ const char * dst;
+ size_t dst_size;
+ trim(src, src_size, dst, dst_size, trim_set);
+ size_t res_offset = row > 0 ? res_offsets[row - 1] : 0;
+ res_data.resize_exact(res_data.size() + dst_size + 1);
+ memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst,
dst_size);
+ res_offset += dst_size + 1;
+ res_data[res_offset - 1] = '\0';
+ res_offsets[row] = res_offset;
}
- void executeVector(
- const ColumnString::Chars & data,
- const ColumnString::Offsets & offsets,
- ColumnString::Chars & res_data,
- ColumnString::Offsets & res_offsets,
- const ColumnString::Chars & trim_data,
- const ColumnString::Offsets & trim_offsets) const
+ std::unique_ptr<std::bitset<256>> buildTrimSet(const char* data, const
size_t size) const
{
- res_data.reserve_exact(data.size());
-
- size_t rows = offsets.size();
- res_offsets.resize_exact(rows);
-
- size_t prev_offset = 0;
- size_t prev_trim_str_offset = 0;
- size_t res_offset = 0;
-
- const UInt8 * start;
- size_t length;
-
- for (size_t i = 0; i < rows; ++i)
- {
- std::unordered_set<char> trim_set(
- &trim_data[prev_trim_str_offset],
&trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1);
-
- trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]),
offsets[i] - prev_offset - 1, start, length, trim_set);
- res_data.resize_exact(res_data.size() + length + 1);
- memcpySmallAllowReadWriteOverflow15(&res_data[res_offset],
start, length);
- res_offset += length + 1;
- res_data[res_offset - 1] = '\0';
-
- res_offsets[i] = res_offset;
- prev_offset = offsets[i];
- prev_trim_str_offset = trim_offsets[i];
- }
+ auto trim_set = std::make_unique<std::bitset<256>>();
+ for (size_t i = 0; i < size; ++i)
+ trim_set->set((unsigned char)data[i]);
+ return trim_set;
}
- void
- trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t
& res_size, const std::unordered_set<char> & trim_set) const
+ void trim(const char * src, const size_t src_size, const char *& dst,
size_t & dst_size, const std::bitset<256> & trim_set) const
{
- const char * char_data = reinterpret_cast<const char *>(data);
- const char * char_end = char_data + size;
+ // If trim column is not constant and contains null value
+ if (trim_set.none())
Review Comment:
The trim column doesn't contain null values as long as
`useDefaultImplementationForNulls()` returns true. See more details in function
`IExecutableFunction::defaultImplementationForNulls`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]