This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 7503ce8378 [GLUTEN-8283][CH] Eliminate CSE via `ExpressionParser` 
(#8284)
7503ce8378 is described below

commit 7503ce8378e35e62d1a346ce0b2876371222b287
Author: lgbo <[email protected]>
AuthorDate: Tue Dec 31 14:51:11 2024 +0800

    [GLUTEN-8283][CH] Eliminate CSE via `ExpressionParser` (#8284)
    
    * eliminate CSE while converting Substrait expressions to an ActionsDAG
    
    * do not match non-deterministic functions
    
    * handle dummy columns
    
    * update
---
 .../execution/GlutenFunctionValidateSuite.scala    |  18 ++-
 cpp-ch/local-engine/Parser/ExpressionParser.cpp    | 166 ++++++++++++++++++++-
 cpp-ch/local-engine/Parser/ExpressionParser.h      |  14 +-
 3 files changed, 183 insertions(+), 15 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 5dcba3b476..e1287c8b6d 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -881,13 +881,25 @@ class GlutenFunctionValidateSuite extends 
GlutenClickHouseWholeStageTransformerS
 
   test("Test transform_keys/transform_values") {
+    // m1 and m2 are both built from the same `map_from_arrays(...)` expression, which gives the native
+    // ExpressionParser a common subexpression to reuse. sort_array(map_entries(...)) makes the compared
+    // output independent of map entry order.
     val sql = """
+                |select id, sort_array(map_entries(m1)), 
sort_array(map_entries(m2)) from(
+                |select id, first(m1) as m1, first(m2) as m2 from(
                 |select
+                |  id,
                 |  transform_keys(map_from_arrays(array(id+1, id+2, id+3),
-                |    array(1, id+2, 3)), (k, v) -> k + 1),
+                |    array(1, id+2, 3)), (k, v) -> k + 1) as m1,
                 |  transform_values(map_from_arrays(array(id+1, id+2, id+3),
-                |    array(1, id+2, 3)), (k, v) -> v + 1)
+                |    array(1, id+2, 3)), (k, v) -> v + 1) as m2
                 |from range(10)
+                |) group by id
+                |) order by id
                 |""".stripMargin
-    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+    def checkProjects(df: DataFrame): Unit = {
+      val projects = collectWithSubqueries(df.queryExecution.executedPlan) {
+        case e: ProjectExecTransformer => e
+      }
+      // Only checks that at least one ProjectExecTransformer is present in the executed plan;
+      // it does not verify that the common subexpression was actually reused.
+      assert(projects.size >= 1)
+    }
+    compareResultsAgainstVanillaSpark(sql, true, checkProjects, false)
   }
 }
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp 
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index ab4d8650d2..7d41c78a3d 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -15,6 +15,7 @@
  * limitations under the License.
  */
 #include "ExpressionParser.h"
+#include <Columns/ColumnSet.h>
 #include <Core/Settings.h>
 #include <DataTypes/DataTypeArray.h>
 #include <DataTypes/DataTypeDate32.h>
@@ -35,6 +36,7 @@
 #include <Parser/ParserContext.h>
 #include <Parser/SerializedPlanParser.h>
 #include <Parser/TypeParser.h>
+#include <Poco/Logger.h>
 #include <Common/BlockTypeUtils.h>
 #include <Common/CHUtil.h>
 #include <Common/logger_useful.h>
@@ -255,16 +257,31 @@ std::pair<DB::DataTypePtr, DB::Field> 
LiteralParser::parse(const substrait::Expr
     return std::make_pair(std::move(type), std::move(field));
 }
 
-const DB::ActionsDAG::Node *
+static const std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser";
+
+/// Whether newly added nodes may be replaced by an already existing, structurally equal node.
+/// Read from the query context configuration key `reuse_cse_in_expression_parser`; enabled when absent.
+bool ExpressionParser::reuseCSE() const
+{
+    const auto & config = context->queryContext()->getConfigRef();
+    return config.getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true);
+}
+
+/// Add a one-row constant column of `type` holding `field` to `actions_dag` and return its node.
+/// When reuseCSE() is enabled and a structurally equal node already exists, that node is returned instead.
+ExpressionParser::NodeRawConstPtr
 ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const 
DB::DataTypePtr type, const DB::Field & field) const
 {
+    // Name: at most the first 10 characters of the field's string form, plus a unique counter suffix.
     String name = toString(field).substr(0, 10);
     name = getUniqueName(name);
-    return 
&actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, 
field), type, name));
+    const auto * res_node = 
&actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, 
field), type, name));
+    if (reuseCSE())
+    {
+        // The freshly added res_node stays in the ActionsDAG but does not affect execution;
+        // it is removed once `ActionsDAG::removeUnusedActions` is called.
+        if (const auto * exists_node = findFirstStructureEqualNode(res_node, 
actions_dag))
+            res_node = exists_node;
+    }
+    return res_node;
 }
 
 
-const ActionsDAG::Node * ExpressionParser::parseExpression(ActionsDAG & 
actions_dag, const substrait::Expression & rel) const
+ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG 
& actions_dag, const substrait::Expression & rel) const
 {
     switch (rel.rex_type_case())
     {
@@ -448,9 +465,7 @@ const ActionsDAG::Node * 
ExpressionParser::parseExpression(ActionsDAG & actions_
             DB::MutableColumnPtr elem_column = elem_type->createColumn();
             elem_column->reserve(options_len);
             for (int i = 0; i < options_len; ++i)
-            {
                 elem_column->insert(options_type_and_field[i].second);
-            }
             auto name = getUniqueName("__set");
             ColumnWithTypeAndName elem_block{std::move(elem_column), 
elem_type, name};
 
@@ -604,7 +619,7 @@ ExpressionParser::parseFunctionArguments(DB::ActionsDAG & 
actions_dag, const sub
     return parsed_args;
 }
 
-const DB::ActionsDAG::Node *
+ExpressionParser::NodeRawConstPtr
 ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & 
func, DB::ActionsDAG & actions_dag, bool add_to_output) const
 {
     auto function_signature = getFunctionNameInSignature(func);
@@ -615,7 +630,7 @@ ExpressionParser::parseFunction(const 
substrait::Expression_ScalarFunction & fun
     return function_node;
 }
 
-const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode(
+ExpressionParser::NodeRawConstPtr ExpressionParser::toFunctionNode(
     DB::ActionsDAG & actions_dag,
     const String & ch_function_name,
     const DB::ActionsDAG::NodeRawConstPtrs & args,
@@ -628,7 +643,19 @@ const DB::ActionsDAG::Node * 
ExpressionParser::toFunctionNode(
         std::string args_name = join(args, ',');
         result_name = ch_function_name + "(" + args_name + ")";
     }
-    return &actions_dag.addFunction(function_builder, args, result_name);
+    const auto * res_node = &actions_dag.addFunction(function_builder, args, 
result_name);
+    if (reuseCSE())
+    {
+        const auto * exists_node = findFirstStructureEqualNode(res_node, 
actions_dag);
+        if (exists_node)
+        {
+            if (result_name_.empty() || result_name == 
exists_node->result_name)
+                res_node = exists_node;
+            else
+                res_node = &actions_dag.addAlias(*exists_node, result_name);
+        }
+    }
+    return res_node;
 }
 
 std::atomic<UInt64> ExpressionParser::unique_name_counter = 0;
@@ -843,4 +870,127 @@ ExpressionParser::parseJsonTuple(const 
substrait::Expression_ScalarFunction & fu
     }
     return res_nodes;
 }
+
+
+/// Result types for which structural node comparison is permitted: numbers, strings, date/time types,
+/// and Nullable/Array/Tuple/Map whose nested types are all permitted as well. Anything else is rejected.
+static bool isAllowedDataType(const DB::IDataType & data_type)
+{
+    const DB::WhichDataType which(data_type);
+
+    if (which.isNullable())
+        return isAllowedDataType(*typeid_cast<const DB::DataTypeNullable *>(&data_type)->getNestedType());
+
+    if (which.isNumber() || which.isStringOrFixedString() || which.isDateOrDate32OrDateTimeOrDateTime64())
+        return true;
+
+    if (which.isArray())
+        return isAllowedDataType(*typeid_cast<const DB::DataTypeArray *>(&data_type)->getNestedType());
+
+    if (which.isTuple())
+    {
+        // Every tuple element has to be permitted.
+        for (const auto & element : typeid_cast<const DB::DataTypeTuple *>(&data_type)->getElements())
+        {
+            if (!isAllowedDataType(*element))
+                return false;
+        }
+        return true;
+    }
+
+    if (which.isMap())
+    {
+        const auto * map = typeid_cast<const DB::DataTypeMap *>(&data_type);
+        return isAllowedDataType(*map->getKeyType()) && isAllowedDataType(*map->getValueType());
+    }
+
+    return false;
+}
+
+/// Returns true when nodes `a` and `b` are structurally equal, so that one can be used in place of the other.
+/// Apart from the trivial `a == b` case, nodes are never equal if they differ in action type, result type or
+/// number of children, if either is non-deterministic, or if the result type is rejected by isAllowedDataType.
+/// Children are compared recursively.
+bool ExpressionParser::areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b)
+{
+    if (a == b)
+        return true;
+
+    if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size()
+        || !a->isDeterministic() || !b->isDeterministic() || !isAllowedDataType(*(a->result_type)))
+        return false;
+
+    switch (a->type)
+    {
+        case DB::ActionsDAG::ActionType::INPUT: {
+            // NOTE(review): inputs are matched by name only; if the input header may contain duplicate
+            // column names, distinct columns would be treated as interchangeable - confirm this cannot happen.
+            if (a->result_name != b->result_name)
+                return false;
+            break;
+        }
+        case DB::ActionsDAG::ActionType::ALIAS: {
+            if (a->result_name != b->result_name)
+                return false;
+            break;
+        }
+        case DB::ActionsDAG::ActionType::COLUMN: {
+            // dummy columns (ColumnSet) cannot be compared by value, so compare their names instead
+            if (typeid_cast<const DB::ColumnSet *>(a->column.get()))
+                return a->result_name == b->result_name;
+            // IColumn::compareAt() is only valid between columns of the same structure (e.g. a const column
+            // expects a const rhs), so verify that before comparing any values.
+            if (!a->column || !b->column || !a->column->structureEquals(*(b->column)))
+                return false;
+            // Compare every row, not only the first one: two multi-row columns that merely share their first
+            // value must not be considered equal. Constants created by addConstColumn hold a single row.
+            const size_t rows = a->column->size();
+            if (rows == 0 || rows != b->column->size())
+                return false;
+            for (size_t row = 0; row < rows; ++row)
+                if (a->column->compareAt(row, row, *(b->column), 1) != 0)
+                    return false;
+            break;
+        }
+        case DB::ActionsDAG::ActionType::ARRAY_JOIN: {
+            return false;
+        }
+        case DB::ActionsDAG::ActionType::FUNCTION: {
+            if (!a->function_base->isDeterministic() || a->function_base->getName() != b->function_base->getName())
+                return false;
+
+            break;
+        }
+        default: {
+            LOG_WARNING(
+                getLogger("ExpressionParser"),
+                "Unknow node type. type:{}, data type:{}, result_name:{}",
+                a->type,
+                a->result_type->getName(),
+                a->result_name);
+            return false;
+        }
+    }
+
+    for (size_t i = 0; i < a->children.size(); ++i)
+        if (!areEqualNodes(a->children[i], b->children[i]))
+            return false;
+    LOG_TEST(
+        getLogger("ExpressionParser"),
+        "Nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
+        a->type,
+        a->result_type->getName(),
+        a->result_name,
+        b->type,
+        b->result_type->getName(),
+        b->result_name);
+    return true;
+}
+
+/// Return the first node of `actions_dag`, other than `target` itself, that is structurally equal to
+/// `target` (see areEqualNodes), or nullptr when there is none.
+/// New nodes are appended to the end of ActionsDAG::nodes, so an earlier equivalent node is found first
+/// and the newly added duplicate is expected to be dropped later.
+ExpressionParser::NodeRawConstPtr
+ExpressionParser::findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const
+{
+    for (const auto & candidate : actions_dag.getNodes())
+    {
+        const auto * candidate_ptr = &candidate;
+        if (candidate_ptr == target || !areEqualNodes(target, candidate_ptr))
+            continue;
+
+        LOG_TEST(
+            getLogger("ExpressionParser"),
+            "Two nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
+            target->type,
+            target->result_type->getName(),
+            target->result_name,
+            candidate_ptr->type,
+            candidate_ptr->result_type->getName(),
+            candidate_ptr->result_name);
+        return candidate_ptr;
+    }
+    return nullptr;
+}
 }
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h 
b/cpp-ch/local-engine/Parser/ExpressionParser.h
index 06a80d756e..1e4a48282a 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.h
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.h
@@ -40,26 +40,27 @@ public:
 class ExpressionParser
 {
 public:
+    using NodeRawConstPtr = const DB::ActionsDAG::Node *;
     ExpressionParser(const std::shared_ptr<const ParserContext> & context_) : 
context(context_) { }
     ~ExpressionParser() = default;
 
     /// Append a counter-suffix to name
     String getUniqueName(const String & name) const;
 
-    const DB::ActionsDAG::Node * addConstColumn(DB::ActionsDAG & actions_dag, 
const DB::DataTypePtr type, const DB::Field & field) const;
+    NodeRawConstPtr addConstColumn(DB::ActionsDAG & actions_dag, const 
DB::DataTypePtr type, const DB::Field & field) const;
 
     /// Parse expr and add an expression node in actions_dag
-    const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAG & actions_dag, 
const substrait::Expression & expr) const;
+    NodeRawConstPtr parseExpression(DB::ActionsDAG & actions_dag, const 
substrait::Expression & expr) const;
     /// Build an actions dag that contains expressions. header is used as 
input columns for the actions dag.
     DB::ActionsDAG expressionsToActionsDAG(const 
std::vector<substrait::Expression> & expressions, const DB::Block & header) 
const;
 
     // Parse func's arguments into actions dag, and return the node ptrs.
     DB::ActionsDAG::NodeRawConstPtrs
     parseFunctionArguments(DB::ActionsDAG & actions_dag, const 
substrait::Expression_ScalarFunction & func) const;
-    const DB::ActionsDAG::Node *
+    NodeRawConstPtr
     parseFunction(const substrait::Expression_ScalarFunction & func, 
DB::ActionsDAG & actions_dag, bool add_to_output = false) const;
     // Add a new function node into the actions dag
-    const DB::ActionsDAG::Node * toFunctionNode(
+    NodeRawConstPtr toFunctionNode(
         DB::ActionsDAG & actions_dag,
         const String & ch_function_name,
         const DB::ActionsDAG::NodeRawConstPtrs & args,
@@ -77,11 +78,16 @@ private:
     static std::atomic<UInt64> unique_name_counter;
     std::shared_ptr<const ParserContext> context;
 
+    bool reuseCSE() const;
+
     DB::ActionsDAG::NodeRawConstPtrs
     parseArrayJoin(const substrait::Expression_ScalarFunction & func, 
DB::ActionsDAG & actions_dag, bool position) const;
     DB::ActionsDAG::NodeRawConstPtrs parseArrayJoinArguments(
         const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & 
actions_dag, bool position, bool & is_map) const;
 
     DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const 
substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) 
const;
+
+    static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b);
+    NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const 
DB::ActionsDAG & actions_dag) const;
 };
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to