This is an automated email from the ASF dual-hosted git repository. lwz9103 pushed a commit to branch liquid in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
commit 6ee01ba8fda61f9579c36e8e78571edeef3777ef
Author: Wenzheng Liu <[email protected]>
AuthorDate: Fri Jun 28 16:13:27 2024 +0800

    support KylinSplitPart (#20)

    (cherry picked from commit 9ccb077de3648e8ac833ab455a5a63e176e00fa9)
---
 .../execution/kap/GlutenKapExpressionsSuite.scala  | 103 ++++++++++++++++++++-
 .../sql/catalyst/expressions/KapExpressions.scala  |  84 +++++++++++++++++
 .../gluten/KapExpressionsTransformer.scala         |  15 ++-
 .../scalar_function_parser/kylinSplitPart.cpp      |  84 +++++++++++++++++
 4 files changed, 283 insertions(+), 3 deletions(-)

diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
index bc3ec1e184..c9237e7483 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.utils.UTSystemParameters
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
-import org.apache.spark.sql.catalyst.expressions.{Expression, KapSubtractMonths, Sum0, YMDintBetween}
+import org.apache.spark.sql.catalyst.expressions.{Expression, KapSubtractMonths, KylinSplitPart, Sum0, YMDintBetween}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.types._
 
@@ -57,6 +57,7 @@ class GlutenKapExpressionsSuite
     registerSparkUdf[Sum0]("sum0")
     registerSparkUdf[KapSubtractMonths]("kap_month_between")
     registerSparkUdf[YMDintBetween]("_ymdint_between")
+    registerSparkUdf[KylinSplitPart]("kylin_split_part")
   }
 
   def createKylinTables(): Unit = {
@@ -317,4 +318,104 @@ class GlutenKapExpressionsSuite
       })
   }
 
+  test("test kylin_split_part") {
+    val sql0 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', 0)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql0,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == null)
+      })
+
+    val sql1 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', 1)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql1,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == "2012")
+      })
+
+    val sql2 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', 2)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql2,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == "01")
+      })
+
+    val sql3 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', 3)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql3,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == "01")
+      })
+
+    val sql4 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', 4)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql4,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == null)
+      })
+
+    val sql5 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '-', -1)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql5,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == "01")
+      })
+
+    val sql6 =
+      s"""
+         |select kylin_split_part(cast(cal_dt as string), '\\\\d{1,2}', -2)
+         |from test_kylin_fact
+         |where cal_dt <= date'2012-03-01'
+         |group by cal_dt order by cal_dt
+         |""".stripMargin
+    compareResultsAgainstVanillaSpark(
+      sql6,
+      compareResult = true,
+      df => {
+        assert(df.head().getString(0) == "-")
+      })
+  }
+
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
index da4fc2c644..497ebf0e1a 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
@@ -177,4 +177,88 @@ case class YMDintBetween(first: Expression, second: Expression)
     super.legacyWithNewChildren(newChildren)
   }
 }
+
+/**
+ * `kylin_split_part(str, regex, index)`: splits `str` around matches of the regular expression
+ * `regex` and returns the `index`-th part (1-based; a negative `index` counts from the end), or
+ * null when `index` is 0 or out of range. The work is done by [[SplitPartImpl.evaluate]].
+ */
+case class KylinSplitPart(left: Expression, mid: Expression, right: Expression)
+  extends TernaryExpression
+  with ExpectsInputTypes {
+
+  override def dataType: DataType = left.dataType
+
+  // SplitPartImpl.evaluate returns null for out-of-range indexes, even for non-null inputs.
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
+
+  override def first: Expression = left
+
+  override def second: Expression = mid
+
+  override def third: Expression = right
+
+  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
+    SplitPartImpl.evaluate(input1.toString, input2.toString, input3.asInstanceOf[Int])
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val ta = SplitPartImpl.getClass.getName.stripSuffix("$")
+    // Use a fresh Java local: when no child is nullable, nullSafeCodeGen emits this snippet
+    // without an enclosing block, so a fixed name would be redeclared if the expression
+    // appears more than once in the same generated method.
+    val result = ctx.freshName("result")
+    nullSafeCodeGen(
+      ctx,
+      ev,
+      (arg1, arg2, arg3) => {
+        s"""
+          org.apache.spark.unsafe.types.UTF8String $result = $ta.evaluate($arg1.toString(), $arg2.toString(), $arg3);
+          if ($result == null) {
+            ${ev.isNull} = true;
+          } else {
+            ${ev.value} = $result;
+          }
+        """
+      }
+    )
+  }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression): Expression = {
+    val newChildren = Seq(newFirst, newSecond, newThird)
+    super.legacyWithNewChildren(newChildren)
+  }
+}
+
+object SplitPartImpl {
+
+  /**
+   * Splits `str` with `String.split(rex)` and returns the `index`-th part.
+   *
+   * @param str   input string
+   * @param rex   Java regular expression used as the separator
+   * @param index 1-based position; negative values count from the end (-1 is the last part)
+   * @return the selected part, or null when `index` is 0 or out of range
+   */
+  def evaluate(str: String, rex: String, index: Int): UTF8String = {
+    // String.split drops trailing empty strings. NOTE(review): the native kylin_split_part
+    // parser uses splitByRegexp, which may keep them - confirm both sides agree.
+    val parts = str.split(rex)
+    if (index > 0 && index <= parts.length) {
+      UTF8String.fromString(parts(index - 1))
+    } else if (index < 0 && index >= -parts.length) {
+      // Compare against -parts.length rather than Math.abs(index): abs(Int.MinValue) overflows
+      // to a negative value and would index out of bounds.
+      UTF8String.fromString(parts(parts.length + index))
+    } else {
+      null
+    }
+  }
+}
+
 // scalastyle:on line.size.limit
diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
index 97a0d63504..52715c49eb 100644
--- a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions.gluten
 
 import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression._
+import org.apache.gluten.expression.{Sig, _}
 import org.apache.gluten.extension.ExpressionExtensionTrait
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -33,7 +33,8 @@ case class KapExpressionsTransformer() extends ExpressionExtensionTrait {
   def expressionSigList: Seq[Sig] = Seq(
     Sig[Sum0]("sum0"),
     Sig[KapSubtractMonths]("kap_month_between"),
-    Sig[YMDintBetween]("kap_ymd_int_between")
+    Sig[YMDintBetween]("kap_ymd_int_between"),
+    Sig[KylinSplitPart]("kylin_split_part")
   )
 
   override def replaceWithExtensionExpressionTransformer(
@@ -60,6 +61,16 @@ case class KapExpressionsTransformer() extends ExpressionExtensionTrait {
         ),
         kapYmdIntBetween
       )
+    case kylinSplitPart: KylinSplitPart if kylinSplitPart.second.isInstanceOf[Literal] =>
+      new GenericExpressionTransformer(
+        substraitExprName,
+        Seq(
+          ExpressionConverter.replaceWithExpressionTransformer(kylinSplitPart.first, attributeSeq),
+          LiteralTransformer(kylinSplitPart.second.asInstanceOf[Literal]),
+          ExpressionConverter.replaceWithExpressionTransformer(kylinSplitPart.third, attributeSeq)
+        ),
+        kylinSplitPart
+      )
     case _ =>
       throw new UnsupportedOperationException(
         s"${expr.getClass} or $expr is not currently supported.")
diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp
new file mode 100644
index 0000000000..c4c0ceab7b
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp
@@ -0,0 +1,84 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#include <Parser/FunctionParser.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+class FunctionParserKylinSplitPart : public FunctionParser // Parses kylin_split_part(str, rex, idx) into ActionsDAG nodes
+{
+public:
+    explicit FunctionParserKylinSplitPart(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+    ~FunctionParserKylinSplitPart() override = default;
+
+    static constexpr auto name = "kylin_split_part";
+
+    String getName() const override { return name; }
+
+    const ActionsDAG::Node * parse(
+        const substrait::Expression_ScalarFunction & substrait_func,
+        ActionsDAGPtr & actions_dag) const override
+    {
+        /*
+            parse kylin_split_part(str, rex, idx) (idx is 1-based; negative counts from the end) as
+            parts = splitByRegexp(rex, str)
+            if (abs(idx) > 0 && abs(idx) <= parts.length)
+                arrayElement(parts, idx)
+            else
+                null  (idx == 0 or out of range)
+        */
+        auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag);
+        if (parsed_args.size() != 3)
+            throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly three arguments", getName());
+
+        const auto * str_node = parsed_args[0];
+        const auto * rex_node = parsed_args[1];
+        const auto * idx_node = parsed_args[2];
+
+        const auto * parts_node = toFunctionNode(actions_dag, "splitByRegexp", {rex_node, str_node}); // NOTE(review): splitByRegexp may keep trailing empty substrings that Java String.split (Spark-side SplitPartImpl) drops - confirm parity
+        const auto * length_node = toFunctionNode(actions_dag, "length", {parts_node});
+
+        // Guard: abs(idx) > 0 && abs(idx) <= length(parts); when it fails the result is NULL
+        const auto * zero_const_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt8>(), 0);
+        const auto * abs_idx_node = toFunctionNode(actions_dag, "abs", {idx_node});
+        const auto * abs_idx_gt_zero_node = toFunctionNode(actions_dag, "greater", {abs_idx_node, zero_const_node});
+        const auto * abs_idx_le_length_node = toFunctionNode(actions_dag, "lessOrEquals", {abs_idx_node, length_node});
+        const auto * condition_node = toFunctionNode(actions_dag, "and", {abs_idx_gt_zero_node, abs_idx_le_length_node});
+        const auto * then_node = toFunctionNode(actions_dag, "arrayElement", {parts_node, idx_node}); // relies on arrayElement counting negative idx from the end
+
+        // Typed NULL of Nullable(String), used as the else-branch when the guard fails
+        auto result_type = std::make_shared<DataTypeString>();
+        const auto * null_const_node = addColumnToActionsDAG(actions_dag, makeNullable(result_type), Field());
+
+        // if(guard, arrayElement(parts, idx), NULL)
+        const auto * result_node = toFunctionNode(actions_dag, "if", {condition_node, then_node, null_const_node});
+
+        return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); // presumably adapts to the output type declared in substrait_func - confirm in FunctionParser
+    }
+};
+
+static FunctionParserRegister<FunctionParserKylinSplitPart> register_kylin_split_part; // static registration of this parser (presumably keyed by name)
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
