This is an automated email from the ASF dual-hosted git repository.

lwz9103 pushed a commit to branch liquid
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git

commit 6ee01ba8fda61f9579c36e8e78571edeef3777ef
Author: Wenzheng Liu <[email protected]>
AuthorDate: Fri Jun 28 16:13:27 2024 +0800

    support KylinSplitPart (#20)
    
    (cherry picked from commit 9ccb077de3648e8ac833ab455a5a63e176e00fa9)
---
 .../execution/kap/GlutenKapExpressionsSuite.scala  | 103 ++++++++++++++++++++-
 .../sql/catalyst/expressions/KapExpressions.scala  |  62 +++++++++++++
 .../gluten/KapExpressionsTransformer.scala         |  15 ++-
 .../scalar_function_parser/kylinSplitPart.cpp      |  84 +++++++++++++++++
 4 files changed, 261 insertions(+), 3 deletions(-)

diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
index bc3ec1e184..c9237e7483 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/kap/GlutenKapExpressionsSuite.scala
@@ -22,7 +22,7 @@ import org.apache.gluten.utils.UTSystemParameters
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
-import org.apache.spark.sql.catalyst.expressions.{Expression, 
KapSubtractMonths, Sum0, YMDintBetween}
+import org.apache.spark.sql.catalyst.expressions.{Expression, 
KapSubtractMonths, KylinSplitPart, Sum0, YMDintBetween}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.types._
 
@@ -57,6 +57,7 @@ class GlutenKapExpressionsSuite
     registerSparkUdf[Sum0]("sum0")
     registerSparkUdf[KapSubtractMonths]("kap_month_between")
     registerSparkUdf[YMDintBetween]("_ymdint_between")
+    registerSparkUdf[KylinSplitPart]("kylin_split_part")
   }
 
   def createKylinTables(): Unit = {
@@ -317,4 +318,104 @@ class GlutenKapExpressionsSuite
       })
   }
 
+  test("test kylin_split_part") {
+    // Table of (delimiter as written inside the SQL string literal, part index,
+    // expected value of the first row). Every query is compared against vanilla
+    // Spark and its first row is additionally pinned to a known value.
+    // NOTE(review): cal_dt presumably renders as 'yyyy-MM-dd' (e.g. '2012-01-01')
+    // when cast to string -- confirm against the test_kylin_fact fixture.
+    val cases: Seq[(String, Int, String)] = Seq(
+      ("-", 0, null), // index 0 never selects a part
+      ("-", 1, "2012"),
+      ("-", 2, "01"),
+      ("-", 3, "01"),
+      ("-", 4, null), // past the last part
+      ("-", -1, "01"), // negative indexes count from the end
+      // Scala "\\\\" -> SQL text '\\d{1,2}', which Spark unescapes to the regex \d{1,2}
+      ("\\\\d{1,2}", -2, "-")
+    )
+
+    cases.foreach {
+      case (delimiter, index, expected) =>
+        val query =
+          s"""
+             |select kylin_split_part(cast(cal_dt as string), '$delimiter', $index)
+             |from test_kylin_fact
+             |where cal_dt <= date'2012-03-01'
+             |group by cal_dt order by cal_dt
+             |""".stripMargin
+        compareResultsAgainstVanillaSpark(
+          query,
+          compareResult = true,
+          df => {
+            assert(df.head().getString(0) == expected)
+          })
+    }
+  }
+
 }
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
index da4fc2c644..497ebf0e1a 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/KapExpressions.scala
@@ -177,4 +177,66 @@ case class YMDintBetween(first: Expression, second: 
Expression)
     super.legacyWithNewChildren(newChildren)
   }
 }
+
+/**
+ * Kylin's `kylin_split_part(str, regex, index)`: splits `str` around matches of the
+ * regular expression `regex` (via [[SplitPartImpl.evaluate]], i.e. Java `String.split`,
+ * so trailing empty parts are dropped) and returns the part selected by `index`.
+ *
+ * `index` is 1-based when positive and counts from the end when negative (-1 is the
+ * last part). Index 0 or an index outside the parts yields NULL, and a NULL in any
+ * argument yields NULL.
+ */
+case class KylinSplitPart(left: Expression, mid: Expression, right: Expression)
+  extends TernaryExpression
+  with ExpectsInputTypes {
+
+  override def dataType: DataType = left.dataType
+
+  // Always nullable: an out-of-range index produces NULL even for non-null inputs.
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
+
+  override def first: Expression = left
+
+  override def second: Expression = mid
+
+  override def third: Expression = right
+
+  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
+    SplitPartImpl.evaluate(input1.toString, input2.toString, input3.asInstanceOf[Int])
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val impl = SplitPartImpl.getClass.getName.stripSuffix("$")
+    nullSafeCodeGen(
+      ctx,
+      ev,
+      (arg1, arg2, arg3) => {
+        // Write the helper's result straight into ev.value rather than a hard-coded
+        // local (e.g. `result`), which could collide with a same-named variable in an
+        // enclosing generated scope; a null result marks the output as NULL.
+        s"""
+          ${ev.value} = $impl.evaluate($arg1.toString(), $arg2.toString(), $arg3);
+          ${ev.isNull} = ${ev.value} == null;
+        """
+      }
+    )
+  }
+
+  override protected def withNewChildrenInternal(
+      newFirst: Expression,
+      newSecond: Expression,
+      newThird: Expression): Expression = {
+    val newChildren = Seq(newFirst, newSecond, newThird)
+    super.legacyWithNewChildren(newChildren)
+  }
+}
+
+/**
+ * Evaluation shared by the interpreted and code-generated paths of [[KylinSplitPart]].
+ */
+object SplitPartImpl {
+
+  /**
+   * Splits `str` with `rex` and returns the selected part.
+   *
+   * @param str   the string to split
+   * @param rex   Java regular expression used as the separator (`String.split`, so
+   *              trailing empty parts are not included)
+   * @param index 1-based position when positive; position from the end when negative
+   * @return the selected part, or null when `index` is 0 or out of range
+   */
+  def evaluate(str: String, rex: String, index: Int): UTF8String = {
+    val parts = str.split(rex)
+    // Resolve to a 0-based array position using Long arithmetic. Math.abs(Int.MinValue)
+    // is still negative, so an abs-based bound check would accept it and then index
+    // outside the array; computing the position and range-checking it avoids that.
+    val pos: Long = if (index > 0) index - 1L else parts.length.toLong + index
+    if (index != 0 && pos >= 0 && pos < parts.length) {
+      UTF8String.fromString(parts(pos.toInt))
+    } else {
+      null
+    }
+  }
+}
+
 // scalastyle:on line.size.limit
diff --git 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
index 97a0d63504..52715c49eb 100644
--- 
a/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
+++ 
b/backends-clickhouse/src/test/scala/org/apache/spark/sql/catalyst/expressions/gluten/KapExpressionsTransformer.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions.gluten
 
 import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression._
+import org.apache.gluten.expression.{Sig, _}
 import org.apache.gluten.extension.ExpressionExtensionTrait
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -33,7 +33,8 @@ case class KapExpressionsTransformer() extends 
ExpressionExtensionTrait {
   def expressionSigList: Seq[Sig] = Seq(
     Sig[Sum0]("sum0"),
     Sig[KapSubtractMonths]("kap_month_between"),
-    Sig[YMDintBetween]("kap_ymd_int_between")
+    Sig[YMDintBetween]("kap_ymd_int_between"),
+    Sig[KylinSplitPart]("kylin_split_part")
   )
 
   override def replaceWithExtensionExpressionTransformer(
@@ -60,6 +61,16 @@ case class KapExpressionsTransformer() extends 
ExpressionExtensionTrait {
         ),
         kapYmdIntBetween
       )
+    case kylinSplitPart: KylinSplitPart if 
kylinSplitPart.second.isInstanceOf[Literal] =>
+      new GenericExpressionTransformer(
+        substraitExprName,
+        Seq(
+          
ExpressionConverter.replaceWithExpressionTransformer(kylinSplitPart.first, 
attributeSeq),
+          LiteralTransformer(kylinSplitPart.second.asInstanceOf[Literal]),
+          
ExpressionConverter.replaceWithExpressionTransformer(kylinSplitPart.third, 
attributeSeq)
+        ),
+        kylinSplitPart
+      )
     case _ =>
       throw new UnsupportedOperationException(
         s"${expr.getClass} or $expr is not currently supported.")
diff --git 
a/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp 
b/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp
new file mode 100644
index 0000000000..c4c0ceab7b
--- /dev/null
+++ b/cpp-ch/local-engine/Parser/scalar_function_parser/kylinSplitPart.cpp
@@ -0,0 +1,84 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#include <Parser/FunctionParser.h>
+
+namespace DB
+{
+namespace ErrorCodes
+{
+extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+}
+
+namespace local_engine
+{
+
+/// Parses Kylin's kylin_split_part(str, rex, idx) into native ClickHouse functions:
+/// splitByRegexp + arrayElement, returning NULL when idx is 0 or out of range.
+/// NOTE(review): the JVM implementation (SplitPartImpl) uses java.lang.String#split,
+/// which drops trailing empty parts; splitByRegexp appears to keep them, so e.g.
+/// kylin_split_part('a-b-', '-', -1) may return '' here but 'b' on the JVM --
+/// confirm and align the two paths.
+/// NOTE(review): Nullable arguments are not handled explicitly here; if str can be
+/// Nullable(String), confirm splitByRegexp accepts it (an Array result cannot itself
+/// be wrapped in Nullable).
+class FunctionParserKylinSplitPart : public FunctionParser
+{
+public:
+   explicit FunctionParserKylinSplitPart(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
+   ~FunctionParserKylinSplitPart() override = default;
+
+   /// Same name as Sig[KylinSplitPart]("kylin_split_part") in KapExpressionsTransformer.
+   static constexpr auto name = "kylin_split_part";
+
+   String getName() const override { return name; }
+
+   /// Builds the action nodes for kylin_split_part and returns the result node,
+   /// converted to the output type declared by substrait_func where needed.
+   /// Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH unless exactly three arguments are given.
+   const ActionsDAG::Node * parse(
+       const substrait::Expression_ScalarFunction & substrait_func,
+       ActionsDAGPtr & actions_dag) const override
+   {
+       /*
+            parse kylin_split_part(str, rex, idx) as
+            parts = splitByRegexp(rex, str)
+            if (abs(idx) > 0 && abs(idx) <= parts.length)
+                arrayElement(parts, idx)
+            else
+                null
+
+            idx is 1-based; a negative idx counts from the end, which arrayElement
+            supports directly. rex is expected to be a constant: the Scala-side
+            transformer only offloads calls whose delimiter is a Literal.
+        */
+       auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag);
+       if (parsed_args.size() != 3)
+           throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly three arguments", getName());
+
+       const auto * str_node = parsed_args[0];
+       const auto * rex_node = parsed_args[1];
+       const auto * idx_node = parsed_args[2];
+
+       // splitByRegexp takes (regexp, string) -- note the argument order.
+       const auto * parts_node = toFunctionNode(actions_dag, "splitByRegexp", {rex_node, str_node});
+       const auto * length_node = toFunctionNode(actions_dag, "length", {parts_node});
+
+       // Range check: abs(idx) > 0 && abs(idx) <= parts.length (idx == 0 is never valid)
+       const auto * zero_const_node = addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeUInt8>(), 0);
+       const auto * abs_idx_node = toFunctionNode(actions_dag, "abs", {idx_node});
+       const auto * abs_idx_gt_zero_node = toFunctionNode(actions_dag, "greater", {abs_idx_node, zero_const_node});
+       const auto * abs_idx_le_length_node = toFunctionNode(actions_dag, "lessOrEquals", {abs_idx_node, length_node});
+       const auto * condition_node = toFunctionNode(actions_dag, "and", {abs_idx_gt_zero_node, abs_idx_le_length_node});
+       // arrayElement may still be computed for rows that fail the check; an
+       // out-of-bounds index there yields a default value, which if() discards.
+       const auto * then_node = toFunctionNode(actions_dag, "arrayElement", {parts_node, idx_node});
+
+       // NULL constant of type Nullable(String) for the out-of-range branch
+       auto result_type = std::make_shared<DataTypeString>();
+       const auto * null_const_node = addColumnToActionsDAG(actions_dag, makeNullable(result_type), Field());
+
+       // if(condition, selected part, NULL)
+       const auto * result_node = toFunctionNode(actions_dag, "if", {condition_node, then_node, null_const_node});
+
+       return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag);
+   }
+};
+
+// Static registration object for this parser (presumably registers it by name when
+// constructed during static initialization -- see FunctionParserRegister).
+static FunctionParserRegister<FunctionParserKylinSplitPart> register_kylin_split_part;
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to