This is an automated email from the ASF dual-hosted git repository.
lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 7503ce8378 [GLUTEN-8283][CH] Eliminate CSE via `ExpressionParser`
(#8284)
7503ce8378 is described below
commit 7503ce8378e35e62d1a346ce0b2876371222b287
Author: lgbo <[email protected]>
AuthorDate: Tue Dec 31 14:51:11 2024 +0800
[GLUTEN-8283][CH] Eliminate CSE via `ExpressionParser` (#8284)
* eliminate cse during convert substrait expression to actions dag
* not match non-deterministic function
* dummy column
* update
---
.../execution/GlutenFunctionValidateSuite.scala | 18 ++-
cpp-ch/local-engine/Parser/ExpressionParser.cpp | 166 ++++++++++++++++++++-
cpp-ch/local-engine/Parser/ExpressionParser.h | 14 +-
3 files changed, 183 insertions(+), 15 deletions(-)
diff --git
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index 5dcba3b476..e1287c8b6d 100644
---
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -881,13 +881,25 @@ class GlutenFunctionValidateSuite extends
GlutenClickHouseWholeStageTransformerS
test("Test transform_keys/transform_values") {
val sql = """
+ |select id, sort_array(map_entries(m1)),
sort_array(map_entries(m2)) from(
+ |select id, first(m1) as m1, first(m2) as m2 from(
|select
+ | id,
| transform_keys(map_from_arrays(array(id+1, id+2, id+3),
- | array(1, id+2, 3)), (k, v) -> k + 1),
+ | array(1, id+2, 3)), (k, v) -> k + 1) as m1,
| transform_values(map_from_arrays(array(id+1, id+2, id+3),
- | array(1, id+2, 3)), (k, v) -> v + 1)
+ | array(1, id+2, 3)), (k, v) -> v + 1) as m2
|from range(10)
+ |) group by id
+ |) order by id
|""".stripMargin
- compareResultsAgainstVanillaSpark(sql, true, { _ => })
+
+ def checkProjects(df: DataFrame): Unit = {
+ val projects = collectWithSubqueries(df.queryExecution.executedPlan) {
+ case e: ProjectExecTransformer => e
+ }
+ assert(projects.size >= 1)
+ }
+ compareResultsAgainstVanillaSpark(sql, true, checkProjects, false)
}
}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
index ab4d8650d2..7d41c78a3d 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -15,6 +15,7 @@
* limitations under the License.
*/
#include "ExpressionParser.h"
+#include <Columns/ColumnSet.h>
#include <Core/Settings.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDate32.h>
@@ -35,6 +36,7 @@
#include <Parser/ParserContext.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/TypeParser.h>
+#include <Poco/Logger.h>
#include <Common/BlockTypeUtils.h>
#include <Common/CHUtil.h>
#include <Common/logger_useful.h>
@@ -255,16 +257,31 @@ std::pair<DB::DataTypePtr, DB::Field>
LiteralParser::parse(const substrait::Expr
return std::make_pair(std::move(type), std::move(field));
}
-const DB::ActionsDAG::Node *
+const static std::string REUSE_COMMON_SUBEXPRESSION_CONF =
"reuse_cse_in_expression_parser";
+
+bool ExpressionParser::reuseCSE() const
+{
+ return
context->queryContext()->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF,
true);
+}
+
+ExpressionParser::NodeRawConstPtr
ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const
DB::DataTypePtr type, const DB::Field & field) const
{
String name = toString(field).substr(0, 10);
name = getUniqueName(name);
- return
&actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1,
field), type, name));
+ const auto * res_node =
&actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1,
field), type, name));
+ if (reuseCSE())
+ {
+ // The newly added node, res_node, remains in the ActionsDAG, but it
does not affect the execution.
+ // It is removed once `ActionsDAG::removeUnusedActions` is
called.
+ if (const auto * exists_node = findFirstStructureEqualNode(res_node,
actions_dag))
+ res_node = exists_node;
+ }
+ return res_node;
}
-const ActionsDAG::Node * ExpressionParser::parseExpression(ActionsDAG &
actions_dag, const substrait::Expression & rel) const
+ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG
& actions_dag, const substrait::Expression & rel) const
{
switch (rel.rex_type_case())
{
@@ -448,9 +465,7 @@ const ActionsDAG::Node *
ExpressionParser::parseExpression(ActionsDAG & actions_
DB::MutableColumnPtr elem_column = elem_type->createColumn();
elem_column->reserve(options_len);
for (int i = 0; i < options_len; ++i)
- {
elem_column->insert(options_type_and_field[i].second);
- }
auto name = getUniqueName("__set");
ColumnWithTypeAndName elem_block{std::move(elem_column),
elem_type, name};
@@ -604,7 +619,7 @@ ExpressionParser::parseFunctionArguments(DB::ActionsDAG &
actions_dag, const sub
return parsed_args;
}
-const DB::ActionsDAG::Node *
+ExpressionParser::NodeRawConstPtr
ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction &
func, DB::ActionsDAG & actions_dag, bool add_to_output) const
{
auto function_signature = getFunctionNameInSignature(func);
@@ -615,7 +630,7 @@ ExpressionParser::parseFunction(const
substrait::Expression_ScalarFunction & fun
return function_node;
}
-const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode(
+ExpressionParser::NodeRawConstPtr ExpressionParser::toFunctionNode(
DB::ActionsDAG & actions_dag,
const String & ch_function_name,
const DB::ActionsDAG::NodeRawConstPtrs & args,
@@ -628,7 +643,19 @@ const DB::ActionsDAG::Node *
ExpressionParser::toFunctionNode(
std::string args_name = join(args, ',');
result_name = ch_function_name + "(" + args_name + ")";
}
- return &actions_dag.addFunction(function_builder, args, result_name);
+ const auto * res_node = &actions_dag.addFunction(function_builder, args,
result_name);
+ if (reuseCSE())
+ {
+ const auto * exists_node = findFirstStructureEqualNode(res_node,
actions_dag);
+ if (exists_node)
+ {
+ if (result_name_.empty() || result_name ==
exists_node->result_name)
+ res_node = exists_node;
+ else
+ res_node = &actions_dag.addAlias(*exists_node, result_name);
+ }
+ }
+ return res_node;
}
std::atomic<UInt64> ExpressionParser::unique_name_counter = 0;
@@ -843,4 +870,127 @@ ExpressionParser::parseJsonTuple(const
substrait::Expression_ScalarFunction & fu
}
return res_nodes;
}
+
+
+static bool isAllowedDataType(const DB::IDataType & data_type)
+{
+ DB::WhichDataType which(data_type);
+ if (which.isNullable())
+ {
+ const auto * null_type = typeid_cast<const DB::DataTypeNullable
*>(&data_type);
+ return isAllowedDataType(*(null_type->getNestedType()));
+ }
+ else if (which.isNumber() || which.isStringOrFixedString() ||
which.isDateOrDate32OrDateTimeOrDateTime64())
+ return true;
+ else if (which.isArray())
+ {
+ auto nested_type = typeid_cast<const DB::DataTypeArray
*>(&data_type)->getNestedType();
+ return isAllowedDataType(*nested_type);
+ }
+ else if (which.isTuple())
+ {
+ const auto * tuple_type = typeid_cast<const DB::DataTypeTuple
*>(&data_type);
+ for (const auto & nested_type : tuple_type->getElements())
+ if (!isAllowedDataType(*nested_type))
+ return false;
+ return true;
+ }
+ else if (which.isMap())
+ {
+ const auto * map_type = typeid_cast<const DB::DataTypeMap
*>(&data_type);
+ return isAllowedDataType(*(map_type->getKeyType())) &&
isAllowedDataType(*(map_type->getValueType()));
+ }
+
+ return false;
+}
+
+bool ExpressionParser::areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b)
+{
+ if (a == b)
+ return true;
+
+ if (a->type != b->type || !a->result_type->equals(*(b->result_type)) ||
a->children.size() != b->children.size()
+ || !a->isDeterministic() || !b->isDeterministic() ||
!isAllowedDataType(*(a->result_type)))
+ return false;
+
+ switch (a->type)
+ {
+ case DB::ActionsDAG::ActionType::INPUT: {
+ if (a->result_name != b->result_name)
+ return false;
+ break;
+ }
+ case DB::ActionsDAG::ActionType::ALIAS: {
+ if (a->result_name != b->result_name)
+ return false;
+ break;
+ }
+ case DB::ActionsDAG::ActionType::COLUMN: {
+ // dummy columns (ColumnSet) cannot be compared by value; compare their names instead
+ if (typeid_cast<const DB::ColumnSet *>(a->column.get()))
+ return a->result_name == b->result_name;
+ if (a->column->compareAt(0, 0, *(b->column), 1) != 0)
+ return false;
+ break;
+ }
+ case DB::ActionsDAG::ActionType::ARRAY_JOIN: {
+ return false;
+ }
+ case DB::ActionsDAG::ActionType::FUNCTION: {
+ if (!a->function_base->isDeterministic() ||
a->function_base->getName() != b->function_base->getName())
+ return false;
+
+ break;
+ }
+ default: {
+ LOG_WARNING(
+ getLogger("ExpressionParser"),
+ "Unknow node type. type:{}, data type:{}, result_name:{}",
+ a->type,
+ a->result_type->getName(),
+ a->result_name);
+ return false;
+ }
+ }
+
+ for (size_t i = 0; i < a->children.size(); ++i)
+ if (!areEqualNodes(a->children[i], b->children[i]))
+ return false;
+ LOG_TEST(
+ getLogger("ExpressionParser"),
+ "Nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data
type:{},name:{}",
+ a->type,
+ a->result_type->getName(),
+ a->result_name,
+ b->type,
+ b->result_type->getName(),
+ b->result_name);
+ return true;
+}
+
+// Since each new node is added at the end of ActionsDAG::nodes, we expect to
find the earlier equal node first; the new node will be dropped later.
+ExpressionParser::NodeRawConstPtr
+ExpressionParser::findFirstStructureEqualNode(NodeRawConstPtr target, const
DB::ActionsDAG & actions_dag) const
+{
+ for (const auto & node : actions_dag.getNodes())
+ {
+ if (target == &node)
+ continue;
+
+ if (areEqualNodes(target, &node))
+ {
+ LOG_TEST(
+ getLogger("ExpressionParser"),
+ "Two nodes are equal:\ntype:{},data
type:{},name:{}\ntype:{},data type:{},name:{}",
+ target->type,
+ target->result_type->getName(),
+ target->result_name,
+ node.type,
+ node.result_type->getName(),
+ node.result_name);
+ return &node;
+ }
+ }
+ return nullptr;
+}
}
diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h
b/cpp-ch/local-engine/Parser/ExpressionParser.h
index 06a80d756e..1e4a48282a 100644
--- a/cpp-ch/local-engine/Parser/ExpressionParser.h
+++ b/cpp-ch/local-engine/Parser/ExpressionParser.h
@@ -40,26 +40,27 @@ public:
class ExpressionParser
{
public:
+ using NodeRawConstPtr = const DB::ActionsDAG::Node *;
ExpressionParser(const std::shared_ptr<const ParserContext> & context_) :
context(context_) { }
~ExpressionParser() = default;
/// Append a counter-suffix to name
String getUniqueName(const String & name) const;
- const DB::ActionsDAG::Node * addConstColumn(DB::ActionsDAG & actions_dag,
const DB::DataTypePtr type, const DB::Field & field) const;
+ NodeRawConstPtr addConstColumn(DB::ActionsDAG & actions_dag, const
DB::DataTypePtr type, const DB::Field & field) const;
/// Parse expr and add an expression node in actions_dag
- const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAG & actions_dag,
const substrait::Expression & expr) const;
+ NodeRawConstPtr parseExpression(DB::ActionsDAG & actions_dag, const
substrait::Expression & expr) const;
/// Build an actions dag that contains expressions. header is used as
input columns for the actions dag.
DB::ActionsDAG expressionsToActionsDAG(const
std::vector<substrait::Expression> & expressions, const DB::Block & header)
const;
// Parse func's arguments into actions dag, and return the node ptrs.
DB::ActionsDAG::NodeRawConstPtrs
parseFunctionArguments(DB::ActionsDAG & actions_dag, const
substrait::Expression_ScalarFunction & func) const;
- const DB::ActionsDAG::Node *
+ NodeRawConstPtr
parseFunction(const substrait::Expression_ScalarFunction & func,
DB::ActionsDAG & actions_dag, bool add_to_output = false) const;
// Add a new function node into the actions dag
- const DB::ActionsDAG::Node * toFunctionNode(
+ NodeRawConstPtr toFunctionNode(
DB::ActionsDAG & actions_dag,
const String & ch_function_name,
const DB::ActionsDAG::NodeRawConstPtrs & args,
@@ -77,11 +78,16 @@ private:
static std::atomic<UInt64> unique_name_counter;
std::shared_ptr<const ParserContext> context;
+ bool reuseCSE() const;
+
DB::ActionsDAG::NodeRawConstPtrs
parseArrayJoin(const substrait::Expression_ScalarFunction & func,
DB::ActionsDAG & actions_dag, bool position) const;
DB::ActionsDAG::NodeRawConstPtrs parseArrayJoinArguments(
const substrait::Expression_ScalarFunction & func, DB::ActionsDAG &
actions_dag, bool position, bool & is_map) const;
DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const
substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag)
const;
+
+ static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b);
+ NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const
DB::ActionsDAG & actions_dag) const;
};
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]