This is an automated email from the ASF dual-hosted git repository.

philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 41b002f2e [VL] Fix wrong result for try_add (#5356)
41b002f2e is described below

commit 41b002f2e5d215f38455ab23233031bc460b9568
Author: Zhen Li <[email protected]>
AuthorDate: Thu Apr 11 18:18:39 2024 +0800

    [VL] Fix wrong result for try_add (#5356)
---
 .../backendsapi/velox/SparkPlanExecApiImpl.scala   | 43 +++++++++++++++++++++-
 .../execution/ScalarFunctionsValidateSuite.scala   |  8 ++++
 cpp/velox/substrait/SubstraitParser.cc             |  3 +-
 .../gluten/backendsapi/SparkPlanExecApi.scala      | 23 ++++++++++++
 .../gluten/expression/ExpressionConverter.scala    | 21 +++++++++++
 .../gluten/expression/ExpressionMappings.scala     |  1 +
 .../apache/gluten/expression/ExpressionNames.scala |  2 +
 .../org/apache/gluten/sql/shims/SparkShims.scala   |  4 ++
 .../gluten/sql/shims/spark34/Spark34Shims.scala    | 14 +++++++
 .../gluten/sql/shims/spark35/Spark35Shims.scala    | 14 +++++++
 10 files changed, 131 insertions(+), 2 deletions(-)

diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
index 715a2eeca..80b371872 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/SparkPlanExecApiImpl.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayFilter, 
Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, 
ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, 
IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, 
PosExplode, Round, SortOrder, StringSplit, StringTrim, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayFilter, 
Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, 
ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, 
IsNaN, LambdaFunction, Literal, Murmur3Hash, NamedExpression, NaNvl, 
PosExplode, Round, SortOrder, StringSplit, StringTrim, TryEval, Uuid}
 import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, 
HLLAdapter}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
@@ -139,6 +139,47 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
       original)
   }
 
+  override def genTryAddTransformer(
+      substraitExprName: String,
+      left: ExpressionTransformer,
+      right: ExpressionTransformer,
+      original: TryEval): ExpressionTransformer = {
+    if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original.child)) {
+      throw new GlutenNotSupportException(s"add with ansi mode is not 
supported")
+    }
+    original.child.dataType match {
+      case LongType | IntegerType | ShortType | ByteType =>
+      case _ => throw new GlutenNotSupportException(s"try_add is not 
supported")
+    }
+    // Offload to velox for only IntegralTypes.
+    GenericExpressionTransformer(
+      substraitExprName,
+      Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, 
right), original)),
+      original)
+  }
+
+  override def genAddTransformer(
+      substraitExprName: String,
+      left: ExpressionTransformer,
+      right: ExpressionTransformer,
+      original: Add): ExpressionTransformer = {
+    if (SparkShimLoader.getSparkShims.withTryEvalMode(original)) {
+      original.dataType match {
+        case LongType | IntegerType | ShortType | ByteType =>
+        case _ => throw new GlutenNotSupportException(s"try_add is not 
supported")
+      }
+      // Offload to velox for only IntegralTypes.
+      GenericExpressionTransformer(
+        ExpressionMappings.expressionsMap(classOf[TryEval]),
+        Seq(GenericExpressionTransformer(ExpressionNames.TRY_ADD, Seq(left, 
right), original)),
+        original)
+    } else if (SparkShimLoader.getSparkShims.withAnsiEvalMode(original)) {
+      throw new GlutenNotSupportException(s"add with ansi mode is not 
supported")
+    } else {
+      GenericExpressionTransformer(substraitExprName, Seq(left, right), 
original)
+    }
+  }
+
   /** Transform map_entries to Substrait. */
   override def genMapEntriesTransformer(
       substraitExprName: String,
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index 9cfe44a8c..bfae3f37b 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -747,4 +747,12 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest {
     }
   }
 
+  test("try_add") {
+    runQueryAndCompare(
+      "select try_add(cast(l_orderkey as int), 1), try_add(cast(l_orderkey as 
int), 2147483647)" +
+        " from lineitem") {
+      checkGlutenOperatorMatch[ProjectExecTransformer]
+    }
+  }
+
 }
diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc
index 75f4c246b..d1d8b0e17 100644
--- a/cpp/velox/substrait/SubstraitParser.cc
+++ b/cpp/velox/substrait/SubstraitParser.cc
@@ -402,7 +402,8 @@ std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunc
     {"xxhash64", "xxhash64_with_seed"},
     {"modulus", "remainder"},
     {"date_format", "format_datetime"},
-    {"collect_set", "set_agg"}};
+    {"collect_set", "set_agg"},
+    {"try_add", "plus"}};
 
 const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = 
{
     {"bool", "BOOLEAN"},
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index faa5e9fd5..e08905162 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -212,6 +212,29 @@ trait SparkPlanExecApi {
     GenericExpressionTransformer(substraitExprName, Seq(), original)
   }
 
+  def genTryAddTransformer(
+      substraitExprName: String,
+      left: ExpressionTransformer,
+      right: ExpressionTransformer,
+      original: TryEval): ExpressionTransformer = {
+    throw new GlutenNotSupportException("try_add is not supported")
+  }
+
+  def genTryAddTransformer(
+      substraitExprName: String,
+      child: ExpressionTransformer,
+      original: TryEval): ExpressionTransformer = {
+    throw new GlutenNotSupportException("try_eval is not supported")
+  }
+
+  def genAddTransformer(
+      substraitExprName: String,
+      left: ExpressionTransformer,
+      right: ExpressionTransformer,
+      original: Add): ExpressionTransformer = {
+    GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
+  }
+
   /** Transform map_entries to Substrait. */
   def genMapEntriesTransformer(
       substraitExprName: String,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index a72be3973..26295678a 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -583,6 +583,27 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           replaceWithExpressionTransformerInternal(f.function, attributeSeq, 
expressionsMap),
           f
         )
+      case tryEval @ TryEval(a: Add) =>
+        BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformerInternal(a.left, attributeSeq, 
expressionsMap),
+          replaceWithExpressionTransformerInternal(a.right, attributeSeq, 
expressionsMap),
+          tryEval
+        )
+      case a: Add =>
+        BackendsApiManager.getSparkPlanExecApiInstance.genAddTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformerInternal(a.left, attributeSeq, 
expressionsMap),
+          replaceWithExpressionTransformerInternal(a.right, attributeSeq, 
expressionsMap),
+          a
+        )
+      case tryEval: TryEval =>
+        // This is a placeholder to handle try_eval(other expressions).
+        BackendsApiManager.getSparkPlanExecApiInstance.genTryAddTransformer(
+          substraitExprName,
+          replaceWithExpressionTransformerInternal(tryEval.child, 
attributeSeq, expressionsMap),
+          tryEval
+        )
       case expr =>
         GenericExpressionTransformer(
           substraitExprName,
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
index 82a884379..ce410842b 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala
@@ -61,6 +61,7 @@ object ExpressionMappings {
     Sig[Not](NOT),
     Sig[IsNaN](IS_NAN),
     Sig[NaNvl](NANVL),
+    Sig[TryEval](TRY_EVAL),
 
     // SparkSQL String functions
     Sig[Ascii](ASCII),
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 6365e6c70..d0be6b599 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -81,6 +81,8 @@ object ExpressionNames {
   final val NOT = "not"
   final val IS_NAN = "isnan"
   final val NANVL = "nanvl"
+  final val TRY_EVAL = "try"
+  final val TRY_ADD = "try_add"
 
   // SparkSQL String functions
   final val ASCII = "ascii"
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index 852a02c47..de7a3d38f 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -195,4 +195,8 @@ trait SparkShims {
 
   def supportsRowBased(plan: SparkPlan): Boolean = !plan.supportsColumnar
 
+  def withTryEvalMode(expr: Expression): Boolean = false
+
+  def withAnsiEvalMode(expr: Expression): Boolean = false
+
 }
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index 438913317..b667ead63 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -361,4 +361,18 @@ class Spark34Shims extends SparkShims {
   }
 
   override def supportsRowBased(plan: SparkPlan): Boolean = 
plan.supportsRowBased
+
+  override def withTryEvalMode(expr: Expression): Boolean = {
+    expr match {
+      case a: Add => a.evalMode == EvalMode.TRY
+      case _ => false
+    }
+  }
+
+  override def withAnsiEvalMode(expr: Expression): Boolean = {
+    expr match {
+      case a: Add => a.evalMode == EvalMode.ANSI
+      case _ => false
+    }
+  }
 }
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 371a1982c..a67b75dcc 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -319,4 +319,18 @@ class Spark35Shims extends SparkShims {
     null
 
   override def supportsRowBased(plan: SparkPlan): Boolean = 
plan.supportsRowBased
+
+  override def withTryEvalMode(expr: Expression): Boolean = {
+    expr match {
+      case a: Add => a.evalMode == EvalMode.TRY
+      case _ => false
+    }
+  }
+
+  override def withAnsiEvalMode(expr: Expression): Boolean = {
+    expr match {
+      case a: Add => a.evalMode == EvalMode.ANSI
+      case _ => false
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to