This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new ddf6a6ff feat: Add support for TryCast expression in Spark 3.2 and 3.3
(#416)
ddf6a6ff is described below
commit ddf6a6ffe34eb26f065397034618b9249131121b
Author: Vipul Vaibhaw <[email protected]>
AuthorDate: Thu May 16 05:25:40 2024 +0530
feat: Add support for TryCast expression in Spark 3.2 and 3.3 (#416)
* working on trycast
* code refactor
* compilation fix
* bug fixes and supporting try_cast
* removing trycast var and comment
* removing issue comment
* adding comments
---
.../org/apache/comet/serde/QueryPlanSerde.scala | 97 +++++++++++++---------
.../scala/org/apache/comet/CometCastSuite.scala | 22 ++---
2 files changed, 64 insertions(+), 55 deletions(-)
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7238990a..cf7c86a9 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -604,6 +604,52 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]):
Option[Expr] = {
SQLConf.get
+
+ def handleCast(
+ child: Expression,
+ inputs: Seq[Attribute],
+ dt: DataType,
+ timeZoneId: Option[String],
+ actualEvalModeStr: String): Option[Expr] = {
+
+ val childExpr = exprToProtoInternal(child, inputs)
+ if (childExpr.isDefined) {
+ val castSupport =
+ CometCast.isSupported(child.dataType, dt, timeZoneId,
actualEvalModeStr)
+
+ def getIncompatMessage(reason: Option[String]): String =
+ "Comet does not guarantee correct results for cast " +
+ s"from ${child.dataType} to $dt " +
+ s"with timezone $timeZoneId and evalMode $actualEvalModeStr" +
+ reason.map(str => s" ($str)").getOrElse("")
+
+ castSupport match {
+ case Compatible(_) =>
+ castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+ case Incompatible(reason) =>
+ if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
+ logWarning(getIncompatMessage(reason))
+ castToProto(timeZoneId, dt, childExpr, actualEvalModeStr)
+ } else {
+ withInfo(
+ expr,
+ s"${getIncompatMessage(reason)}. To enable all incompatible
casts, set " +
+ s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
+ None
+ }
+ case Unsupported =>
+ withInfo(
+ expr,
+ s"Unsupported cast from ${child.dataType} to $dt " +
+ s"with timezone $timeZoneId and evalMode $actualEvalModeStr")
+ None
+ }
+ } else {
+ withInfo(expr, child)
+ None
+ }
+ }
+
expr match {
case a @ Alias(_, _) =>
val r = exprToProtoInternal(a.child, inputs)
@@ -617,50 +663,19 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
val value = cast.eval()
exprToProtoInternal(Literal(value, dataType), inputs)
+ case UnaryExpression(child) if expr.prettyName == "trycast" =>
+ val timeZoneId = SQLConf.get.sessionLocalTimeZone
+ handleCast(child, inputs, expr.dataType, Some(timeZoneId), "TRY")
+
case Cast(child, dt, timeZoneId, evalMode) =>
- val childExpr = exprToProtoInternal(child, inputs)
- if (childExpr.isDefined) {
- val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
- // Spark 3.2 & 3.3 has ansiEnabled boolean
- if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
- } else {
- // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
- evalMode.toString
- }
- val castSupport =
- CometCast.isSupported(child.dataType, dt, timeZoneId,
evalModeStr)
-
- def getIncompatMessage(reason: Option[String]) =
- "Comet does not guarantee correct results for cast " +
- s"from ${child.dataType} to $dt " +
- s"with timezone $timeZoneId and evalMode $evalModeStr" +
- reason.map(str => s" ($str)").getOrElse("")
-
- castSupport match {
- case Compatible(_) =>
- castToProto(timeZoneId, dt, childExpr, evalModeStr)
- case Incompatible(reason) =>
- if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
- logWarning(getIncompatMessage(reason))
- castToProto(timeZoneId, dt, childExpr, evalModeStr)
- } else {
- withInfo(
- expr,
- s"${getIncompatMessage(reason)}. To enable all
incompatible casts, set " +
- s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
- None
- }
- case Unsupported =>
- withInfo(
- expr,
- s"Unsupported cast from ${child.dataType} to $dt " +
- s"with timezone $timeZoneId and evalMode $evalModeStr")
- None
- }
+ val evalModeStr = if (evalMode.isInstanceOf[Boolean]) {
+ // Spark 3.2 & 3.3 has ansiEnabled boolean
+ if (evalMode.asInstanceOf[Boolean]) "ANSI" else "LEGACY"
} else {
- withInfo(expr, child)
- None
+ // Spark 3.4+ has EvalMode enum with values LEGACY, ANSI, and TRY
+ evalMode.toString
}
+ handleCast(child, inputs, dt, timeZoneId, evalModeStr)
case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
val leftExpr = exprToProtoInternal(left, inputs)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 50311a9b..ea3355d0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -886,10 +886,7 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
private def castTest(input: DataFrame, toType: DataType): Unit = {
- // we do not support the TryCast expression in Spark 3.2 and 3.3
- // https://github.com/apache/datafusion-comet/issues/374
- val testTryCast = CometSparkSessionExtensions.isSpark34Plus
-
+ // we now support the TryCast expression in Spark 3.2 and 3.3
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
data.createOrReplaceTempView("t")
@@ -900,11 +897,9 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
checkSparkAnswerAndOperator(df)
// try_cast() should always return null for invalid inputs
- if (testTryCast) {
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
- checkSparkAnswerAndOperator(df2)
- }
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
+ checkSparkAnswerAndOperator(df2)
}
// with ANSI enabled, we should produce the same exception as Spark
@@ -963,11 +958,10 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
// try_cast() should always return null for invalid inputs
- if (testTryCast) {
- val df2 =
- spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
- checkSparkAnswerAndOperator(df2)
- }
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by
a")
+ checkSparkAnswerAndOperator(df2)
+
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]