This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new ddf6a6ff feat: Add support for TryCast expression in Spark 3.2 and 3.3 (#416)
ddf6a6ff is described below

commit ddf6a6ffe34eb26f065397034618b9249131121b
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 16 05:25:40 2024 +0530

    feat: Add support for TryCast expression in Spark 3.2 and 3.3 (#416)
    
    * working on trycast
    
    * code refactor
    
    * compilation fix
    
    * bug fixes and support for try_cast
    
    * removing trycast var and comment
    
    * removing issue comment
    
    * adding comments
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 97 +++++++++++++---------
 .../scala/org/apache/comet/CometCastSuite.scala    | 22 ++---
 2 files changed, 64 insertions(+), 55 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7238990a..cf7c86a9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -604,6 +604,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
     def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): 
Option[Expr] = {
       SQLConf.get
+
+      def handleCast(
+          child: Expression,
+          inputs: Seq[Attribute],
+          dt: DataType,
+          timeZoneId: Option[String],
+          actualEvalModeStr: String): Option[Expr] = {
+
+        val childExpr = exprToProtoInternal(child, inputs)
+        if (childExpr.isDefined) {
+          val castSupport =
+            CometCast.isSupported(child.dataType, dt, timeZoneId, 
actualEvalModeStr)
+
+          def getIncompatMessage(reason: Option[String]): String =
+            "Comet does not guarantee correct results for cast " +
+              s"from ${child.dataType} to $dt " +
+              s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
+              reason.map(str => s" ($str)").getOrElse("")
+
+          castSupport match {
+            case Compatible(_) =>
+              castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+            case Incompatible(reason) =>
+              if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
+                logWarning(getIncompatMessage(reason))
+                castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+              } else {
+                withInfo(
+                  expr,
+                  s"${getIncompatMessage(reason)}. To enable all incompatible 
casts, set " +
+                    s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
+                None
+              }
+            case Unsupported =>
+              withInfo(
+                expr,
+                s"Unsupported cast from ${child.dataType} to $dt " +
+                  s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
+              None
+          }
+        } else {
+          withInfo(expr, child)
+          None
+        }
+      }
+
       expr match {
         case a @ Alias(_, _) =>
           val r = exprToProtoInternal(a.child, inputs)
@@ -617,50 +663,19 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           val value = cast.eval()
           exprToProtoInternal(Literal(value, dataType), inputs)
 
+        case UnaryExpression(child) if expr.prettyName == "trycast" =>
+          val timeZoneId = SQLConf.get.sessionLocalTimeZone
+          handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
+
         case Cast(child, dt, timeZoneId, evalMode) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          if (childExpr.isDefined) {
-            val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
-              // Spark 3.2 & 3.3 has ansiEnabled boolean
-              if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
-            } else {
-              // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
-              evalMode.toString
-            }
-            val castSupport =
-              CometCast.isSupported(child.dataType, dt, timeZoneId, 
evalModeStr)
-
-            def getIncompatMessage(reason: Option[String]) =
-              "Comet does not guarantee correct results for cast " +
-                s"from ${child.dataType} to $dt " +
-                s"with timezone $timeZoneId and evalMode $evalModeStr" +
-                reason.map(str => s" ($str)").getOrElse("")
-
-            castSupport match {
-              case Compatible(_) =>
-                castToProto(timeZoneId, dt, childExpr, evalModeStr)
-              case Incompatible(reason) =>
-                if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
-                  logWarning(getIncompatMessage(reason))
-                  castToProto(timeZoneId, dt, childExpr, evalModeStr)
-                } else {
-                  withInfo(
-                    expr,
-                    s"${getIncompatMessage(reason)}. To enable all 
incompatible casts, set " +
-                      s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
-                  None
-                }
-              case Unsupported =>
-                withInfo(
-                  expr,
-                  s"Unsupported cast from ${child.dataType} to $dt " +
-                    s"with timezone $timeZoneId and evalMode $evalModeStr")
-                None
-            }
+          val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
+            // Spark 3.2 & 3.3 has ansiEnabled boolean
+            if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
           } else {
-            withInfo(expr, child)
-            None
+            // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
+            evalMode.toString
           }
+          handleCast(child, inputs, dt, timeZoneId, evalModeStr)
 
         case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
           val leftExpr = exprToProtoInternal(left, inputs)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 50311a9b..ea3355d0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -886,10 +886,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   private def castTest(input: DataFrame, toType: DataType): Unit = {
 
-    // we do not support the TryCast expression in Spark 3.2 and 3.3
-    // https://github.com/apache/datafusion-comet/issues/374
-    val testTryCast = CometSparkSessionExtensions.isSpark34Plus
-
+    // we now support the TryCast expression in Spark 3.2 and 3.3
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
       data.createOrReplaceTempView("t")
@@ -900,11 +897,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         checkSparkAnswerAndOperator(df)
 
         // try_cast() should always return null for invalid inputs
-        if (testTryCast) {
-          val df2 =
-            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
-          checkSparkAnswerAndOperator(df2)
-        }
+        val df2 =
+          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
+        checkSparkAnswerAndOperator(df2)
       }
 
       // with ANSI enabled, we should produce the same exception as Spark
@@ -963,11 +958,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         }
 
         // try_cast() should always return null for invalid inputs
-        if (testTryCast) {
-          val df2 =
-            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
-          checkSparkAnswerAndOperator(df2)
-        }
+        val df2 =
+          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
+        checkSparkAnswerAndOperator(df2)
+
       }
     }
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to