This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 2eb38cfd4 minor refactors on expand (#6861)
2eb38cfd4 is described below

commit 2eb38cfd4ed040203f0b7bf98a94231dfc914f01
Author: 李扬 <[email protected]>
AuthorDate: Thu Aug 15 14:22:27 2024 +0800

    minor refactors on expand (#6861)
---
 cpp-ch/local-engine/Operator/ExpandStep.cpp      |  8 ++-
 cpp-ch/local-engine/Operator/ExpandTransform.cpp | 70 ++++++++++--------------
 cpp-ch/local-engine/Operator/ExpandTransorm.h    | 10 ++--
 3 files changed, 40 insertions(+), 48 deletions(-)

diff --git a/cpp-ch/local-engine/Operator/ExpandStep.cpp b/cpp-ch/local-engine/Operator/ExpandStep.cpp
index 9f56d9fd9..8770c4c40 100644
--- a/cpp-ch/local-engine/Operator/ExpandStep.cpp
+++ b/cpp-ch/local-engine/Operator/ExpandStep.cpp
@@ -52,16 +52,18 @@ ExpandStep::ExpandStep(const DB::DataStream & input_stream_, const ExpandField &
     output_header = getOutputStream().header;
 }
 
-DB::Block ExpandStep::buildOutputHeader(const DB::Block & , const ExpandField & project_set_exprs_)
+DB::Block ExpandStep::buildOutputHeader(const DB::Block &, const ExpandField & project_set_exprs_)
 {
     DB::ColumnsWithTypeAndName cols;
     const auto & types = project_set_exprs_.getTypes();
     const auto & names = project_set_exprs_.getNames();
 
+    chassert(names.size() == types.size());
+
     for (size_t i = 0; i < project_set_exprs_.getExpandCols(); ++i)
-        cols.push_back(DB::ColumnWithTypeAndName(types[i], names[i]));
+        cols.emplace_back(DB::ColumnWithTypeAndName(types[i], names[i]));
 
-    return DB::Block(cols);
+    return DB::Block(std::move(cols));
 }
 
 void ExpandStep::transformPipeline(DB::QueryPipelineBuilder & pipeline, const DB::BuildQueryPipelineSettings & /*settings*/)
diff --git a/cpp-ch/local-engine/Operator/ExpandTransform.cpp b/cpp-ch/local-engine/Operator/ExpandTransform.cpp
index f5787163c..5100ad070 100644
--- a/cpp-ch/local-engine/Operator/ExpandTransform.cpp
+++ b/cpp-ch/local-engine/Operator/ExpandTransform.cpp
@@ -15,19 +15,20 @@
  * limitations under the License.
  */
 #include <memory>
+#include <Poco/Logger.h>
 
 #include <Columns/ColumnNullable.h>
 #include <Columns/ColumnsNumber.h>
 #include <Columns/IColumn.h>
 #include <DataTypes/DataTypeNullable.h>
 #include <DataTypes/DataTypesNumber.h>
+#include <Interpreters/castColumn.h>
 #include <Processors/IProcessor.h>
-#include "ExpandTransorm.h"
-
-#include <Poco/Logger.h>
 #include <Common/Exception.h>
 #include <Common/logger_useful.h>
 
+#include "ExpandTransorm.h"
+
 namespace DB
 {
 namespace ErrorCodes
@@ -93,53 +94,42 @@ void ExpandTransform::work()
     if (expand_expr_iterator >= project_set_exprs.getExpandRows())
         throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "expand_expr_iterator >= project_set_exprs.getExpandRows()");
 
-    const auto & original_cols = input_chunk.getColumns();
+    const auto & input_header = getInputs().front().getHeader();
+    const auto & input_columns = input_chunk.getColumns();
+    const auto & types = project_set_exprs.getTypes();
+    const auto & kinds = project_set_exprs.getKinds()[expand_expr_iterator];
+    const auto & fields = project_set_exprs.getFields()[expand_expr_iterator];
     size_t rows = input_chunk.getNumRows();
-    DB::Columns cols;
-    for (size_t j = 0; j < project_set_exprs.getExpandCols(); ++j)
+
+    DB::Columns columns(types.size());
+    for (size_t col_i = 0; col_i < types.size(); ++col_i)
     {
-        const auto & type = project_set_exprs.getTypes()[j];
-        const auto & kind = project_set_exprs.getKinds()[expand_expr_iterator][j];
-        const auto & field = project_set_exprs.getFields()[expand_expr_iterator][j];
+        const auto & type = types[col_i];
+        const auto & kind = kinds[col_i];
+        const auto & field = fields[col_i];
 
         if (kind == EXPAND_FIELD_KIND_SELECTION)
         {
-            const auto & original_col = original_cols.at(field.get<Int32>());
-            if (type->isNullable() == original_col->isNullable())
-            {
-                cols.push_back(original_col);
-            }
-            else if (type->isNullable() && !original_col->isNullable())
-            {
-                auto null_map = DB::ColumnUInt8::create(rows, 0);
-                auto col = DB::ColumnNullable::create(original_col, std::move(null_map));
-                cols.push_back(std::move(col));
-            }
-            else
-            {
-                throw DB::Exception(
-                    DB::ErrorCodes::LOGICAL_ERROR,
-                    "Miss match nullable, column {} is nullable, but type {} is not nullable",
-                    original_col->getName(),
-                    type->getName());
-            }
+            auto index = field.get<Int32>();
+            const auto & input_column = input_columns[index];
+
+            DB::ColumnWithTypeAndName input_arg;
+            input_arg.column = input_column;
+            input_arg.type = input_header.getByPosition(index).type;
+            /// input_column maybe non-Nullable
+            columns[col_i] = DB::castColumn(input_arg, type);
         }
-        else if (field.isNull())
+        else if (kind == EXPAND_FIELD_KIND_LITERAL)
         {
-            // Add null column
-            auto null_map = DB::ColumnUInt8::create(rows, 1);
-            auto nested_type = DB::removeNullable(type);
-            auto col = DB::ColumnNullable::create(nested_type->createColumn()->cloneResized(rows), std::move(null_map));
-            cols.push_back(std::move(col));
+            /// Add const column with field value
+            auto column = type->createColumnConst(rows, field);
+            columns[col_i] = column;
         }
         else
-        {
-            // Add constant column: gid, gpos, etc.
-            auto col = type->createColumnConst(rows, field);
-            cols.push_back(col->convertToFullColumnIfConst());
-        }
+            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknown ExpandFieldKind {}", magic_enum::enum_name(kind));
     }
-    output_chunk = DB::Chunk(cols, rows);
+
+    output_chunk = DB::Chunk(std::move(columns), rows);
     has_output = true;
 
     ++expand_expr_iterator;
diff --git a/cpp-ch/local-engine/Operator/ExpandTransorm.h b/cpp-ch/local-engine/Operator/ExpandTransorm.h
index 90bdf3dc1..f315ca5db 100644
--- a/cpp-ch/local-engine/Operator/ExpandTransorm.h
+++ b/cpp-ch/local-engine/Operator/ExpandTransorm.h
@@ -15,21 +15,21 @@
  * limitations under the License.
  */
 #pragma once
-#include <set>
-#include <vector>
+
 #include <Core/Block.h>
 #include <Parser/ExpandField.h>
 #include <Processors/Chunk.h>
 #include <Processors/IProcessor.h>
 #include <Processors/Port.h>
+
 namespace local_engine
 {
 // For handling substrait expand node.
 // The implementation in spark for groupingsets/rollup/cube is different from Clickhouse.
-// We have to ways to support groupingsets/rollup/cube
-// - rewrite the substrait plan in local engine and reuse the implementation of clickhouse. This
+// We have two ways to support groupingsets/rollup/cube
+// - Rewrite the substrait plan in local engine and reuse the implementation of clickhouse. This
 //   may be more complex.
-// - implement new transform to do the expandation. It's more simple, but may suffer some performance
+// - Implement new transform to do the expandation. It's simpler, but may suffer some performance
 //   issues. We try this first.
 class ExpandTransform : public DB::IProcessor
 {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to