lwz9103 commented on code in PR #6992:
URL: https://github.com/apache/incubator-gluten/pull/6992#discussion_r1732526297


##########
cpp-ch/local-engine/Functions/SparkFunctionTrim.cpp:
##########
@@ -106,125 +106,114 @@ namespace
         ColumnPtr
         executeImpl(const ColumnsWithTypeAndName & arguments, const 
DataTypePtr & /*result_type*/, size_t input_rows_count) const override
         {
-            const ColumnString * src_str_col = 
checkAndGetColumn<ColumnString>(arguments[0].column.get());
-            if (!src_str_col)
-                throw Exception(ErrorCodes::ILLEGAL_COLUMN, "First argument of 
function {} must be String", getName());
+            const ColumnString * src_col = 
checkAndGetColumn<ColumnString>(arguments[0].column.get());
+            const ColumnConst * src_const_col = 
checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
+            const ColumnString * trim_col = 
checkAndGetColumn<ColumnString>(arguments[1].column.get());
+            const ColumnConst * trim_const_col = 
checkAndGetColumnConst<ColumnString>(arguments[1].column.get());
+
+            String src_const_str;
+            String trim_const_str;
+            if (src_const_col)
+                src_const_str = src_const_col->getValue<String>();
+            if (trim_const_col)
+                trim_const_str = trim_const_col->getValue<String>();
+            if (trim_const_col && trim_const_str.empty()) {
+                return arguments[0].column;
+            }
 
+            // If both arguments are constants, it will be simplified to a 
constant. Skipped here.
 
-            if (const auto * trim_const_str_col = 
checkAndGetColumnConst<ColumnString>(arguments[1].column.get()))
-            {
-                String trim_str = trim_const_str_col->getValue<String>();
-                if (trim_str.empty())
-                    return src_str_col->cloneResized(input_rows_count);
+            auto res_col = ColumnString::create();
+            ColumnString::Chars & res_data = res_col->getChars();
+            ColumnString::Offsets & res_offsets = res_col->getOffsets();
+            res_offsets.resize_exact(input_rows_count);
 
-                auto res_col = ColumnString::create();
-                res_col->reserve(input_rows_count);
-                executeVector(src_str_col->getChars(), 
src_str_col->getOffsets(), res_col->getChars(), res_col->getOffsets(), 
trim_str);
+            // Source column is constant and trim column is not constant
+            if (src_const_col)
+            {
+                res_data.reserve_exact(src_const_str.size() * 
input_rows_count);
+                for (size_t row = 0; row < input_rows_count; ++row)
+                {
+                    StringRef trim_str_ref = trim_col->getDataAt(row);
+                    std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_str_ref.toString());
+                    executeRow(src_const_str.c_str(), src_const_str.size(), 
res_data, res_offsets, row, trim_set);
+                }
                 return std::move(res_col);
             }
-            else if (const auto * trim_str_col = 
checkAndGetColumn<ColumnString>(arguments[1].column.get()))
+
+            // Source column is not constant and trim column is constant
+            if (trim_const_col)
             {
-                auto res_col = ColumnString::create();
-                res_col->reserve(input_rows_count);
-
-                executeVector(
-                    src_str_col->getChars(),
-                    src_str_col->getOffsets(),
-                    res_col->getChars(),
-                    res_col->getOffsets(),
-                    trim_str_col->getChars(),
-                    trim_str_col->getOffsets());
+                res_data.reserve_exact(src_col->getChars().size());
+                std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_const_str);
+                for (size_t row = 0; row < input_rows_count; ++row)
+                {
+                    StringRef src_str_ref = src_col->getDataAt(row);
+                    executeRow(src_str_ref.data, src_str_ref.size, res_data, 
res_offsets, row, trim_set);
+                }
                 return std::move(res_col);
             }
 
-            throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Second argument of 
function {} must be String or Const String", getName());
+            // Both columns are not constant
+            res_data.reserve(src_col->getChars().size());
+            for (size_t row = 0; row < input_rows_count; ++row)
+            {
+                StringRef src_str_ref = src_col->getDataAt(row);
+                StringRef trim_str_ref = trim_col->getDataAt(row);
+                std::unique_ptr<std::bitset<256>> trim_set = 
buildTrimSet(trim_str_ref.toString());
+                executeRow(src_str_ref.data, src_str_ref.size, res_data, 
res_offsets, row, trim_set);
+            }
+            return std::move(res_col);
         }
 
     private:
-        void executeVector(
-            const ColumnString::Chars & data,
-            const ColumnString::Offsets & offsets,
+        void executeRow(
+            const char * src,
+            size_t src_size,
             ColumnString::Chars & res_data,
             ColumnString::Offsets & res_offsets,
-            const String & trim_str) const
+            size_t & row,
+            const std::unique_ptr<std::bitset<256>> & trim_set) const
         {
-            res_data.reserve_exact(data.size());
-
-            size_t rows = offsets.size();
-            res_offsets.resize_exact(rows);
-
-            size_t prev_offset = 0;
-            size_t res_offset = 0;
-
-            const UInt8 * start;
-            size_t length;
-            std::unordered_set<char> trim_set(trim_str.begin(), 
trim_str.end());
-            for (size_t i = 0; i < rows; ++i)
-            {
-                trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), 
offsets[i] - prev_offset - 1, start, length, trim_set);
-                res_data.resize_exact(res_data.size() + length + 1);
-                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], 
start, length);
-                res_offset += length + 1;
-                res_data[res_offset - 1] = '\0';
-
-                res_offsets[i] = res_offset;
-                prev_offset = offsets[i];
-            }
+            const char * dst;
+            size_t dst_size;
+            trim(src, src_size, dst, dst_size, trim_set);
+            size_t res_offset = row > 0 ? res_offsets[row - 1] : 0;
+            res_data.resize_exact(res_data.size() + dst_size + 1);
+            memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], dst, 
dst_size);
+            res_offset += dst_size + 1;
+            res_data[res_offset - 1] = '\0';
+            res_offsets[row] = res_offset;
         }
 
-        void executeVector(
-            const ColumnString::Chars & data,
-            const ColumnString::Offsets & offsets,
-            ColumnString::Chars & res_data,
-            ColumnString::Offsets & res_offsets,
-            const ColumnString::Chars & trim_data,
-            const ColumnString::Offsets & trim_offsets) const
+        std::unique_ptr<std::bitset<256>> buildTrimSet(const String& trim_str) 
const
         {
-            res_data.reserve_exact(data.size());
-
-            size_t rows = offsets.size();
-            res_offsets.resize_exact(rows);
-
-            size_t prev_offset = 0;
-            size_t prev_trim_str_offset = 0;
-            size_t res_offset = 0;
-
-            const UInt8 * start;
-            size_t length;
-
-            for (size_t i = 0; i < rows; ++i)
-            {
-                std::unordered_set<char> trim_set(
-                    &trim_data[prev_trim_str_offset], 
&trim_data[prev_trim_str_offset] + trim_offsets[i] - prev_trim_str_offset - 1);
-
-                trim(reinterpret_cast<const UInt8 *>(&data[prev_offset]), 
offsets[i] - prev_offset - 1, start, length, trim_set);
-                res_data.resize_exact(res_data.size() + length + 1);
-                memcpySmallAllowReadWriteOverflow15(&res_data[res_offset], 
start, length);
-                res_offset += length + 1;
-                res_data[res_offset - 1] = '\0';
-
-                res_offsets[i] = res_offset;
-                prev_offset = offsets[i];
-                prev_trim_str_offset = trim_offsets[i];
-            }
+            auto trim_set = std::make_unique<std::bitset<256>>();
+            for (unsigned char i : trim_str)
+                trim_set->set(i);
+            return trim_set;
         }
 
-        void
-        trim(const UInt8 * data, size_t size, const UInt8 *& res_data, size_t 
& res_size, const std::unordered_set<char> & trim_set) const
+        void trim(const char * src, size_t src_size, const char *& dst, size_t 
& dst_size, const std::unique_ptr<std::bitset<256>> & trim_set) const
         {
-            const char * char_data = reinterpret_cast<const char *>(data);
-            const char * char_end = char_data + size;
+            if (!trim_set || trim_set->none())

Review Comment:
   Removed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to