This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 311e13e3 chore: Add CometEvalMode enum to replace string literals (#539)
311e13e3 is described below

commit 311e13e3ec5fe40e10e0f2671e317e1142272af4
Author: Andy Grove <[email protected]>
AuthorDate: Fri Jun 7 15:15:47 2024 -0600

    chore: Add CometEvalMode enum to replace string literals (#539)
    
    * Add CometEvalMode enum
    
    * address feedback
---
 .../main/scala/org/apache/comet/GenerateDocs.scala |  6 +--
 .../org/apache/comet/expressions/CometCast.scala   |  4 +-
 .../apache/comet/expressions/CometEvalMode.scala}  | 29 +++++++----
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 57 +++++++++-------------
 .../org/apache/comet/shims/CometExprShim.scala     |  6 ++-
 .../org/apache/comet/shims/CometExprShim.scala     |  5 +-
 .../org/apache/comet/shims/CometExprShim.scala     | 15 +++++-
 .../org/apache/comet/shims/CometExprShim.scala     | 13 +++++
 .../scala/org/apache/comet/CometCastSuite.scala    |  4 +-
 9 files changed, 84 insertions(+), 55 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
index a2d5e251..fb86389f 100644
--- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
+++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
@@ -25,7 +25,7 @@ import scala.io.Source
 
 import org.apache.spark.sql.catalyst.expressions.Cast
 
-import org.apache.comet.expressions.{CometCast, Compatible, Incompatible}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible}
 
 /**
  * Utility for generating markdown documentation from the configs.
@@ -72,7 +72,7 @@ object GenerateDocs {
             if (Cast.canCast(fromType, toType) && fromType != toType) {
               val fromTypeName = fromType.typeName.replace("(10,2)", "")
               val toTypeName = toType.typeName.replace("(10,2)", "")
-              CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+              CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
                 case Compatible(notes) =>
                   val notesStr = notes.getOrElse("").trim
                   w.write(s"| $fromTypeName | $toTypeName | $notesStr 
|\n".getBytes)
@@ -89,7 +89,7 @@ object GenerateDocs {
             if (Cast.canCast(fromType, toType) && fromType != toType) {
               val fromTypeName = fromType.typeName.replace("(10,2)", "")
               val toTypeName = toType.typeName.replace("(10,2)", "")
-              CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+              CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
                 case Incompatible(notes) =>
                   val notesStr = notes.getOrElse("").trim
                   w.write(s"| $fromTypeName | $toTypeName  | $notesStr 
|\n".getBytes)
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 11c5a53c..811c61d4 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -55,7 +55,7 @@ object CometCast {
       fromType: DataType,
       toType: DataType,
       timeZoneId: Option[String],
-      evalMode: String): SupportLevel = {
+      evalMode: CometEvalMode.Value): SupportLevel = {
 
     if (fromType == toType) {
       return Compatible()
@@ -102,7 +102,7 @@ object CometCast {
   private def canCastFromString(
       toType: DataType,
       timeZoneId: Option[String],
-      evalMode: String): SupportLevel = {
+      evalMode: CometEvalMode.Value): SupportLevel = {
     toType match {
       case DataTypes.BooleanType =>
         Compatible()
diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
similarity index 52%
copy from spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
index f5a578f8..59e9c89a 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
@@ -16,18 +16,27 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.comet.shims
 
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.comet.expressions
 
 /**
- * `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
+ * We cannot reference Spark's EvalMode directly because the package is different between Spark
+ * versions, so we copy it here.
+ *
+ * Expression evaluation modes.
+ *   - LEGACY: the default evaluation mode, which is compliant with Hive SQL.
+ *   - ANSI: an evaluation mode which is compliant with the ANSI SQL standard.
+ *   - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI evaluation mode
+ *     except for returning a null result on errors.
  */
-trait CometExprShim {
-    /**
-     * Returns a tuple of expressions for the `unhex` function.
-     */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
-        (unhex.child, Literal(false))
-    }
+object CometEvalMode extends Enumeration {
+  val LEGACY, ANSI, TRY = Value
+
+  def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+    ANSI
+  } else {
+    LEGACY
+  }
+
+  def fromString(str: String): CometEvalMode.Value = CometEvalMode.withName(str)
 }
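
A minimal sketch of how the new enum behaves, based only on the definition
above (the calls and results below are illustrative, not part of the commit):

    import org.apache.comet.expressions.CometEvalMode

    // fromBoolean maps the ansiEnabled flag that Spark 3.2/3.3 Cast carries
    CometEvalMode.fromBoolean(true)   // ANSI
    CometEvalMode.fromBoolean(false)  // LEGACY

    // fromString delegates to Enumeration.withName, which accepts only the
    // exact names LEGACY, ANSI and TRY, and throws NoSuchElementException
    // otherwise (including for lowercase input such as "legacy")
    CometEvalMode.fromString("TRY")   // TRY
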
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 448c4ff0..ed3f2fae 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,8 +19,6 @@
 
 package org.apache.comet.serde
 
-import java.util.Locale
-
 import scala.collection.JavaConverters._
 
 import org.apache.spark.internal.Logging
@@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled, isCometScan, isSpark32, isSpark34Plus, withInfo}
-import org.apache.comet.expressions.{CometCast, Compatible, Incompatible, Unsupported}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, Unsupported}
 import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
 import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
 import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, JoinType, Operator}
@@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     }
   }
 
+  def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {
+    evalMode match {
+      case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
+      case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY
+      case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
+      case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode")
+    }
+  }
+
   /**
    * Convert a Spark expression to protobuf.
    *
@@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
    * @return
   *   The protobuf representation of the expression, or None if the expression is not supported
    */
-
-  def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode =
-    evalModeStr.toUpperCase(Locale.ROOT) match {
-      case "LEGACY" => ExprOuterClass.EvalMode.LEGACY
-      case "TRY" => ExprOuterClass.EvalMode.TRY
-      case "ANSI" => ExprOuterClass.EvalMode.ANSI
-      case invalid =>
-        throw new IllegalArgumentException(
-          s"Invalid eval mode '$invalid' "
-        ) // Assuming we want to catch errors strictly
-    }
-
   def exprToProto(
       expr: Expression,
       input: Seq[Attribute],
@@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         timeZoneId: Option[String],
         dt: DataType,
         childExpr: Option[Expr],
-        evalMode: String): Option[Expr] = {
+        evalMode: CometEvalMode.Value): Option[Expr] = {
       val dataType = serializeDataType(dt)
-      val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum
 
       if (childExpr.isDefined && dataType.isDefined) {
         val castBuilder = ExprOuterClass.Cast.newBuilder()
         castBuilder.setChild(childExpr.get)
         castBuilder.setDatatype(dataType.get)
-        castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf
+        castBuilder.setEvalMode(evalModeToProto(evalMode))
 
         val timeZone = timeZoneId.getOrElse("UTC")
         castBuilder.setTimezone(timeZone)
@@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           inputs: Seq[Attribute],
           dt: DataType,
           timeZoneId: Option[String],
-          actualEvalModeStr: String): Option[Expr] = {
+          evalMode: CometEvalMode.Value): Option[Expr] = {
 
         val childExpr = exprToProtoInternal(child, inputs)
         if (childExpr.isDefined) {
           val castSupport =
-            CometCast.isSupported(child.dataType, dt, timeZoneId, actualEvalModeStr)
+            CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
 
           def getIncompatMessage(reason: Option[String]): String =
             "Comet does not guarantee correct results for cast " +
               s"from ${child.dataType} to $dt " +
-              s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
+              s"with timezone $timeZoneId and evalMode $evalMode" +
               reason.map(str => s" ($str)").getOrElse("")
 
           castSupport match {
             case Compatible(_) =>
-              castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+              castToProto(timeZoneId, dt, childExpr, evalMode)
             case Incompatible(reason) =>
               if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
                 logWarning(getIncompatMessage(reason))
-                castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+                castToProto(timeZoneId, dt, childExpr, evalMode)
               } else {
                 withInfo(
                   expr,
@@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
               withInfo(
                 expr,
                 s"Unsupported cast from ${child.dataType} to $dt " +
-                  s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
+                  s"with timezone $timeZoneId and evalMode $evalMode")
               None
           }
         } else {
@@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 
         case UnaryExpression(child) if expr.prettyName == "trycast" =>
           val timeZoneId = SQLConf.get.sessionLocalTimeZone
-          handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
+          handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY)
 
-        case Cast(child, dt, timeZoneId, evalMode) =>
-          val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
-            // Spark 3.2 & 3.3 has ansiEnabled boolean
-            if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
-          } else {
-            // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
-            evalMode.toString
-          }
-          handleCast(child, inputs, dt, timeZoneId, evalModeStr)
+        case c @ Cast(child, dt, timeZoneId, _) =>
+          handleCast(child, inputs, dt, timeZoneId, evalMode(c))
 
         case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
           val leftExpr = exprToProtoInternal(left, inputs)
@@ -2006,7 +1993,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           // TODO: Remove this once we have new DataFusion release which includes
           // the fix: https://github.com/apache/arrow-datafusion/pull/9459
           if (childExpr.isDefined) {
-            castToProto(None, a.dataType, childExpr, "LEGACY")
+            castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY)
           } else {
             withInfo(expr, a.children: _*)
             None
diff --git a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
index f5a578f8..2c6f6ccf 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.comet.shims
 
+import org.apache.comet.expressions.CometEvalMode
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
@@ -27,7 +28,10 @@ trait CometExprShim {
     /**
      * Returns a tuple of expressions for the `unhex` function.
      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+    protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(false))
     }
+
+    protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled)
 }
+
diff --git a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
index f5a578f8..150656c2 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.comet.shims
 
+import org.apache.comet.expressions.CometEvalMode
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
@@ -27,7 +28,9 @@ trait CometExprShim {
     /**
      * Returns a tuple of expressions for the `unhex` function.
      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+    protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(false))
     }
+
+    protected def evalMode(c: Cast): CometEvalMode.Value = CometEvalMode.fromBoolean(c.ansiEnabled)
 }
diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index 3f2301f0..5f4e3fba 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.comet.shims
 
+import org.apache.comet.expressions.CometEvalMode
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
@@ -27,7 +28,19 @@ trait CometExprShim {
     /**
      * Returns a tuple of expressions for the `unhex` function.
      */
-    def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+    protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(unhex.failOnError))
     }
+
+    protected def evalMode(c: Cast): CometEvalMode.Value =
+        CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
 }
+
+object CometEvalModeUtil {
+    def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match {
+        case EvalMode.LEGACY => CometEvalMode.LEGACY
+        case EvalMode.TRY => CometEvalMode.TRY
+        case EvalMode.ANSI => CometEvalMode.ANSI
+    }
+}
+
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index 01f92320..5f4e3fba 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
  */
 package org.apache.comet.shims
 
+import org.apache.comet.expressions.CometEvalMode
 import org.apache.spark.sql.catalyst.expressions._
 
 /**
@@ -30,4 +31,16 @@ trait CometExprShim {
     protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
         (unhex.child, Literal(unhex.failOnError))
     }
+
+    protected def evalMode(c: Cast): CometEvalMode.Value =
+        CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
 }
+
+object CometEvalModeUtil {
+    def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value = evalMode match {
+        case EvalMode.LEGACY => CometEvalMode.LEGACY
+        case EvalMode.TRY => CometEvalMode.TRY
+        case EvalMode.ANSI => CometEvalMode.ANSI
+    }
+}
+
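
Taken together, the shims give QueryPlanSerde a single evalMode(c: Cast)
entry point whose body differs by Spark profile; the build compiles exactly
one CometExprShim into the jar. A consolidated sketch (illustrative only; in
the repository each definition lives in its own spark-3.x/spark-4.0 source
tree):

    // Spark 3.2 / 3.3: Cast exposes only an ansiEnabled boolean.
    protected def evalMode(c: Cast): CometEvalMode.Value =
        CometEvalMode.fromBoolean(c.ansiEnabled)

    // Spark 3.4 / 4.0: Cast carries Spark's three-valued EvalMode, mapped
    // value-by-value through CometEvalModeUtil.fromSparkEvalMode.
    protected def evalMode(c: Cast): CometEvalMode.Value =
        CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
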
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index fd221896..25343f93 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}
 
-import org.apache.comet.expressions.{CometCast, Compatible}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible}
 
 class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
@@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
             } else {
               val testIgnored =
                tags.get(expectedTestName).exists(s => s.contains("org.scalatest.Ignore"))
-              CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+              CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
                 case Compatible(_) =>
                   if (testIgnored) {
                     fail(
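
The suite consults the same entry point to decide which casts must have a
passing (non-ignored) test. A condensed sketch of the check, with the
surrounding type loops and test-name bookkeeping elided:

    CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) match {
      case Compatible(_) =>
        // a test must exist and must not be ignored; otherwise fail(...)
      case _ =>
        // incompatible or unsupported casts need no passing test
    }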


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
