This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 311e13e3 chore: Add CometEvalMode enum to replace string literals
(#539)
311e13e3 is described below
commit 311e13e3ec5fe40e10e0f2671e317e1142272af4
Author: Andy Grove <[email protected]>
AuthorDate: Fri Jun 7 15:15:47 2024 -0600
chore: Add CometEvalMode enum to replace string literals (#539)
* Add CometEvalMode enum
* address feedback
---
.../main/scala/org/apache/comet/GenerateDocs.scala | 6 +--
.../org/apache/comet/expressions/CometCast.scala | 4 +-
.../apache/comet/expressions/CometEvalMode.scala} | 29 +++++++----
.../org/apache/comet/serde/QueryPlanSerde.scala | 57 +++++++++-------------
.../org/apache/comet/shims/CometExprShim.scala | 6 ++-
.../org/apache/comet/shims/CometExprShim.scala | 5 +-
.../org/apache/comet/shims/CometExprShim.scala | 15 +++++-
.../org/apache/comet/shims/CometExprShim.scala | 13 +++++
.../scala/org/apache/comet/CometCastSuite.scala | 4 +-
9 files changed, 84 insertions(+), 55 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
index a2d5e251..fb86389f 100644
--- a/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
+++ b/spark/src/main/scala/org/apache/comet/GenerateDocs.scala
@@ -25,7 +25,7 @@ import scala.io.Source
import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.comet.expressions.{CometCast, Compatible, Incompatible}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible,
Incompatible}
/**
* Utility for generating markdown documentation from the configs.
@@ -72,7 +72,7 @@ object GenerateDocs {
if (Cast.canCast(fromType, toType) && fromType != toType) {
val fromTypeName = fromType.typeName.replace("(10,2)", "")
val toTypeName = toType.typeName.replace("(10,2)", "")
- CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+ CometCast.isSupported(fromType, toType, None,
CometEvalMode.LEGACY) match {
case Compatible(notes) =>
val notesStr = notes.getOrElse("").trim
w.write(s"| $fromTypeName | $toTypeName | $notesStr
|\n".getBytes)
@@ -89,7 +89,7 @@ object GenerateDocs {
if (Cast.canCast(fromType, toType) && fromType != toType) {
val fromTypeName = fromType.typeName.replace("(10,2)", "")
val toTypeName = toType.typeName.replace("(10,2)", "")
- CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+ CometCast.isSupported(fromType, toType, None,
CometEvalMode.LEGACY) match {
case Incompatible(notes) =>
val notesStr = notes.getOrElse("").trim
w.write(s"| $fromTypeName | $toTypeName | $notesStr
|\n".getBytes)
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 11c5a53c..811c61d4 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -55,7 +55,7 @@ object CometCast {
fromType: DataType,
toType: DataType,
timeZoneId: Option[String],
- evalMode: String): SupportLevel = {
+ evalMode: CometEvalMode.Value): SupportLevel = {
if (fromType == toType) {
return Compatible()
@@ -102,7 +102,7 @@ object CometCast {
private def canCastFromString(
toType: DataType,
timeZoneId: Option[String],
- evalMode: String): SupportLevel = {
+ evalMode: CometEvalMode.Value): SupportLevel = {
toType match {
case DataTypes.BooleanType =>
Compatible()
diff --git
a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
similarity index 52%
copy from spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
copy to spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
index f5a578f8..59e9c89a 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometEvalMode.scala
@@ -16,18 +16,27 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.comet.shims
-import org.apache.spark.sql.catalyst.expressions._
+package org.apache.comet.expressions
/**
- * `CometExprShim` acts as a shim for for parsing expressions from different
Spark versions.
+ * We cannot reference Spark's EvalMode directly because the package is
different between Spark
+ * versions, so we copy it here.
+ *
+ * Expression evaluation modes.
+ * - LEGACY: the default evaluation mode, which is compliant with Hive SQL.
+ * - ANSI: an evaluation mode which is compliant with the ANSI SQL standard.
+ * - TRY: an evaluation mode for `try_*` functions. It is identical to ANSI
evaluation mode
+ * except for returning a null result on errors.
*/
-trait CometExprShim {
- /**
- * Returns a tuple of expressions for the `unhex` function.
- */
- def unhexSerde(unhex: Unhex): (Expression, Expression) = {
- (unhex.child, Literal(false))
- }
+object CometEvalMode extends Enumeration {
+ val LEGACY, ANSI, TRY = Value
+
+ def fromBoolean(ansiEnabled: Boolean): Value = if (ansiEnabled) {
+ ANSI
+ } else {
+ LEGACY
+ }
+
+ def fromString(str: String): CometEvalMode.Value =
CometEvalMode.withName(str)
}
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 448c4ff0..ed3f2fae 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -19,8 +19,6 @@
package org.apache.comet.serde
-import java.util.Locale
-
import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
@@ -45,7 +43,7 @@ import org.apache.spark.unsafe.types.UTF8String
import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isCometOperatorEnabled,
isCometScan, isSpark32, isSpark34Plus, withInfo}
-import org.apache.comet.expressions.{CometCast, Compatible, Incompatible,
Unsupported}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible,
Incompatible, Unsupported}
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType =>
ProtoDataType, Expr, ScalarFunc}
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo,
DecimalInfo, ListInfo, MapInfo, StructInfo}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode =>
CometAggregateMode, JoinType, Operator}
@@ -578,6 +576,15 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
}
}
+ def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode
= {
+ evalMode match {
+ case CometEvalMode.LEGACY => ExprOuterClass.EvalMode.LEGACY
+ case CometEvalMode.TRY => ExprOuterClass.EvalMode.TRY
+ case CometEvalMode.ANSI => ExprOuterClass.EvalMode.ANSI
+ case _ => throw new IllegalStateException(s"Invalid evalMode $evalMode")
+ }
+ }
+
/**
* Convert a Spark expression to protobuf.
*
@@ -590,18 +597,6 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
* @return
* The protobuf representation of the expression, or None if the
expression is not supported
*/
-
- def stringToEvalMode(evalModeStr: String): ExprOuterClass.EvalMode =
- evalModeStr.toUpperCase(Locale.ROOT) match {
- case "LEGACY" => ExprOuterClass.EvalMode.LEGACY
- case "TRY" => ExprOuterClass.EvalMode.TRY
- case "ANSI" => ExprOuterClass.EvalMode.ANSI
- case invalid =>
- throw new IllegalArgumentException(
- s"Invalid eval mode '$invalid' "
- ) // Assuming we want to catch errors strictly
- }
-
def exprToProto(
expr: Expression,
input: Seq[Attribute],
@@ -610,15 +605,14 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
timeZoneId: Option[String],
dt: DataType,
childExpr: Option[Expr],
- evalMode: String): Option[Expr] = {
+ evalMode: CometEvalMode.Value): Option[Expr] = {
val dataType = serializeDataType(dt)
- val evalModeEnum = stringToEvalMode(evalMode) // Convert string to enum
if (childExpr.isDefined && dataType.isDefined) {
val castBuilder = ExprOuterClass.Cast.newBuilder()
castBuilder.setChild(childExpr.get)
castBuilder.setDatatype(dataType.get)
- castBuilder.setEvalMode(evalModeEnum) // Set the enum in protobuf
+ castBuilder.setEvalMode(evalModeToProto(evalMode))
val timeZone = timeZoneId.getOrElse("UTC")
castBuilder.setTimezone(timeZone)
@@ -646,26 +640,26 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
inputs: Seq[Attribute],
dt: DataType,
timeZoneId: Option[String],
- actualEvalModeStr: String): Option[Expr] = {
+ evalMode: CometEvalMode.Value): Option[Expr] = {
val childExpr = exprToProtoInternal(child, inputs)
if (childExpr.isDefined) {
val castSupport =
- CometCast.isSupported(child.dataType, dt, timeZoneId,
actualEvalModeStr)
+ CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
def getIncompatMessage(reason: Option[String]): String =
"Comet does not guarantee correct results for cast " +
s"from ${child.dataType} to $dt " +
- s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
+ s"with timezone $timeZoneId and evalMode $evalMode" +
reason.map(str => s" ($str)").getOrElse("")
castSupport match {
case Compatible(_) =>
- castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+ castToProto(timeZoneId, dt, childExpr, evalMode)
case Incompatible(reason) =>
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
logWarning(getIncompatMessage(reason))
- castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+ castToProto(timeZoneId, dt, childExpr, evalMode)
} else {
withInfo(
expr,
@@ -677,7 +671,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
withInfo(
expr,
s"Unsupported cast from ${child.dataType} to $dt " +
- s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
+ s"with timezone $timeZoneId and evalMode $evalMode")
None
}
} else {
@@ -701,17 +695,10 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
case UnaryExpression(child) if expr.prettyName == "trycast" =>
val timeZoneId = SQLConf.get.sessionLocalTimeZone
- handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
+ handleCast(child, inputs, expr.dataType, Some(timeZoneId),
CometEvalMode.TRY)
- case Cast(child, dt, timeZoneId, evalMode) =>
- val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
- // Spark 3.2 & 3.3 has ansiEnabled boolean
- if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
- } else {
- // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
- evalMode.toString
- }
- handleCast(child, inputs, dt, timeZoneId, evalModeStr)
+ case c @ Cast(child, dt, timeZoneId, _) =>
+ handleCast(child, inputs, dt, timeZoneId, evalMode(c))
case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
@@ -2006,7 +1993,7 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
// TODO: Remove this once we have new DataFusion release which
includes
// the fix: https://github.com/apache/arrow-datafusion/pull/9459
if (childExpr.isDefined) {
- castToProto(None, a.dataType, childExpr, "LEGACY")
+ castToProto(None, a.dataType, childExpr, CometEvalMode.LEGACY)
} else {
withInfo(expr, a.children: _*)
None
diff --git
a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
index f5a578f8..2c6f6ccf 100644
--- a/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.2/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
*/
package org.apache.comet.shims
+import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -27,7 +28,10 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
- def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}
+
+ protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalMode.fromBoolean(c.ansiEnabled)
}
+
diff --git
a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
index f5a578f8..150656c2 100644
--- a/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.3/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
*/
package org.apache.comet.shims
+import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -27,7 +28,9 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
- def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(false))
}
+
+ protected def evalMode(c: Cast): CometEvalMode.Value =
CometEvalMode.fromBoolean(c.ansiEnabled)
}
diff --git
a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
index 3f2301f0..5f4e3fba 100644
--- a/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-3.4/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
*/
package org.apache.comet.shims
+import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -27,7 +28,19 @@ trait CometExprShim {
/**
* Returns a tuple of expressions for the `unhex` function.
*/
- def unhexSerde(unhex: Unhex): (Expression, Expression) = {
+ protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}
+
+ protected def evalMode(c: Cast): CometEvalMode.Value =
+ CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
}
+
+object CometEvalModeUtil {
+ def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value =
evalMode match {
+ case EvalMode.LEGACY => CometEvalMode.LEGACY
+ case EvalMode.TRY => CometEvalMode.TRY
+ case EvalMode.ANSI => CometEvalMode.ANSI
+ }
+}
+
diff --git
a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
index 01f92320..5f4e3fba 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
@@ -18,6 +18,7 @@
*/
package org.apache.comet.shims
+import org.apache.comet.expressions.CometEvalMode
import org.apache.spark.sql.catalyst.expressions._
/**
@@ -30,4 +31,16 @@ trait CometExprShim {
protected def unhexSerde(unhex: Unhex): (Expression, Expression) = {
(unhex.child, Literal(unhex.failOnError))
}
+
+ protected def evalMode(c: Cast): CometEvalMode.Value =
+ CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
}
+
+object CometEvalModeUtil {
+ def fromSparkEvalMode(evalMode: EvalMode.Value): CometEvalMode.Value =
evalMode match {
+ case EvalMode.LEGACY => CometEvalMode.LEGACY
+ case EvalMode.TRY => CometEvalMode.TRY
+ case EvalMode.ANSI => CometEvalMode.ANSI
+ }
+}
+
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index fd221896..25343f93 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType}
-import org.apache.comet.expressions.{CometCast, Compatible}
+import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible}
class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -76,7 +76,7 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
} else {
val testIgnored =
tags.get(expectedTestName).exists(s =>
s.contains("org.scalatest.Ignore"))
- CometCast.isSupported(fromType, toType, None, "LEGACY") match {
+ CometCast.isSupported(fromType, toType, None,
CometEvalMode.LEGACY) match {
case Compatible(_) =>
if (testIgnored) {
fail(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]