This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3951e33  [SPARK-34881][SQL] New SQL Function: TRY_CAST
3951e33 is described below

commit 3951e3371a83578a81474ed99fb50d59f27aac62
Author: Gengliang Wang <ltn...@gmail.com>
AuthorDate: Wed Mar 31 20:47:04 2021 +0800

    [SPARK-34881][SQL] New SQL Function: TRY_CAST
    
    ### What changes were proposed in this pull request?
    
    Add a new SQL function `try_cast`.
    `try_cast` is identical to  `AnsiCast` (or `Cast` when 
`spark.sql.ansi.enabled` is true), except it returns NULL instead of raising an 
error.
    This expression has one major difference from `cast` with 
`spark.sql.ansi.enabled` as false: when the source value can't be stored in the 
target integral (Byte/Short/Int/Long) type, `try_cast` returns null instead of 
returning the low order bytes of the source value.
    Note that the result of `try_cast` is not affected by the configuration 
`spark.sql.ansi.enabled`.
    
    This is learned from Google BigQuery and Snowflake:
    https://docs.snowflake.com/en/sql-reference/functions/try_cast.html
    
https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-and-operators#safe_casting
    
    ### Why are the changes needed?
    
    This is useful for the following scenarios:
    1. When ANSI mode is on, users can choose `try_cast` as an alternative way to 
run SQL without errors for certain operations.
    2. When ANSI mode is off, users can use `try_cast` to get a more reasonable 
result for casting a value to an integral type: when an overflow error happens, 
`try_cast` returns null while `cast` returns the low order bytes of the source 
value.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, adding a new function `try_cast`
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #31982 from gengliangwang/tryCast.
    
    Authored-by: Gengliang Wang <ltn...@gmail.com>
    Signed-off-by: Gengliang Wang <ltn...@gmail.com>
---
 docs/sql-ref-ansi-compliance.md                    |   1 +
 .../apache/spark/sql/catalyst/parser/SqlBase.g4    |   5 +-
 .../spark/sql/catalyst/expressions/Cast.scala      |  27 +--
 .../spark/sql/catalyst/expressions/TryCast.scala   |  85 ++++++++
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   8 +-
 .../spark/sql/catalyst/expressions/CastSuite.scala |  52 +++--
 .../sql/catalyst/expressions/TryCastSuite.scala    |  51 +++++
 .../test/resources/sql-tests/inputs/try_cast.sql   |  54 +++++
 .../resources/sql-tests/results/try_cast.sql.out   | 234 +++++++++++++++++++++
 9 files changed, 486 insertions(+), 31 deletions(-)

diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index f4fd712..70a1fa3 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -434,6 +434,7 @@ Below is a list of all the keywords in Spark SQL.
 |TRIM|non-reserved|non-reserved|non-reserved|
 |TRUE|non-reserved|non-reserved|reserved|
 |TRUNCATE|non-reserved|non-reserved|reserved|
+|TRY_CAST|non-reserved|non-reserved|non-reserved|
 |TYPE|non-reserved|non-reserved|non-reserved|
 |UNARCHIVE|non-reserved|non-reserved|non-reserved|
 |UNBOUNDED|non-reserved|non-reserved|non-reserved|
diff --git 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index e694eda..55ba375 100644
--- 
a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ 
b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -805,7 +805,7 @@ primaryExpression
     : name=(CURRENT_DATE | CURRENT_TIMESTAMP)                                  
                #currentDatetime
     | CASE whenClause+ (ELSE elseExpression=expression)? END                   
                #searchedCase
     | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END  
                #simpleCase
-    | CAST '(' expression AS dataType ')'                                      
                #cast
+    | name=(CAST | TRY_CAST) '(' expression AS dataType ')'                    
                #cast
     | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? 
')'             #struct
     | FIRST '(' expression (IGNORE NULLS)? ')'                                 
                #first
     | LAST '(' expression (IGNORE NULLS)? ')'                                  
                #last
@@ -1199,6 +1199,7 @@ ansiNonReserved
     | TRIM
     | TRUE
     | TRUNCATE
+    | TRY_CAST
     | TYPE
     | UNARCHIVE
     | UNBOUNDED
@@ -1461,6 +1462,7 @@ nonReserved
     | TRIM
     | TRUE
     | TRUNCATE
+    | TRY_CAST
     | TYPE
     | UNARCHIVE
     | UNBOUNDED
@@ -1720,6 +1722,7 @@ TRANSFORM: 'TRANSFORM';
 TRIM: 'TRIM';
 TRUE: 'TRUE';
 TRUNCATE: 'TRUNCATE';
+TRY_CAST: 'TRY_CAST';
 TYPE: 'TYPE';
 UNARCHIVE: 'UNARCHIVE';
 UNBOUNDED: 'UNBOUNDED';
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a696c40..6b18563 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -298,7 +298,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
   def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
 
   // [[func]] assumes the input is no longer null because eval already does 
the null check.
-  @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = 
func(a.asInstanceOf[T])
+  @inline protected[this] def buildCast[T](a: Any, func: T => Any): Any = 
func(a.asInstanceOf[T])
 
   private lazy val dateFormatter = DateFormatter(zoneId)
   private lazy val timestampFormatter = 
TimestampFormatter.getFractionFormatter(zoneId)
@@ -810,7 +810,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     })
   }
 
-  private[this] def cast(from: DataType, to: DataType): Any => Any = {
+  protected[this] def cast(from: DataType, to: DataType): Any => Any = {
     // If the cast does not change the structure, then we don't really need to 
cast anything.
     // We can return what the children return. Same thing should happen in the 
codegen path.
     if (DataType.equalsStructurally(from, to)) {
@@ -849,7 +849,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
     }
   }
 
-  private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
+  protected[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
   protected override def nullSafeEval(input: Any): Any = cast(input)
 
@@ -873,7 +873,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
 
   // The function arguments are: `input`, `result` and `resultIsNull`. We 
don't need `inputIsNull`
   // in parameter list, because the returned code will be put in null safe 
evaluation region.
-  private[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block
+  protected[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => 
Block
 
   private[this] def nullSafeCastFunction(
       from: DataType,
@@ -908,7 +908,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit
 
   // Since we need to cast input expressions recursively inside ComplexTypes, 
such as Map's
   // Key and Value, Struct's field, we need to name out all the variable names 
involved in a cast.
-  private[this] def castCode(ctx: CodegenContext, input: ExprValue, 
inputIsNull: ExprValue,
+  protected[this] def castCode(ctx: CodegenContext, input: ExprValue, 
inputIsNull: ExprValue,
     result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: 
CastFunction): Block = {
     val javaType = JavaCode.javaType(resultType)
     code"""
@@ -1795,7 +1795,8 @@ case class Cast(child: Expression, dataType: DataType, 
timeZoneId: Option[String
   }
 
   override def typeCheckFailureMessage: String = if (ansiEnabled) {
-    AnsiCast.typeCheckFailureMessage(child.dataType, dataType, 
SQLConf.ANSI_ENABLED.key, "false")
+    AnsiCast.typeCheckFailureMessage(child.dataType, dataType,
+      Some(SQLConf.ANSI_ENABLED.key), Some("false"))
   } else {
     s"cannot cast ${child.dataType.catalogString} to ${dataType.catalogString}"
   }
@@ -1823,8 +1824,10 @@ case class AnsiCast(child: Expression, dataType: 
DataType, timeZoneId: Option[St
   // If there are more scenarios for this expression, we should update the 
error message on type
   // check failure.
   override def typeCheckFailureMessage: String =
-    AnsiCast.typeCheckFailureMessage(child.dataType, dataType,
-      SQLConf.STORE_ASSIGNMENT_POLICY.key, 
SQLConf.StoreAssignmentPolicy.LEGACY.toString)
+    AnsiCast.typeCheckFailureMessage(child.dataType,
+      dataType,
+      Some(SQLConf.STORE_ASSIGNMENT_POLICY.key),
+      Some(SQLConf.StoreAssignmentPolicy.LEGACY.toString))
 
 }
 
@@ -1940,8 +1943,8 @@ object AnsiCast {
   def typeCheckFailureMessage(
       from: DataType,
       to: DataType,
-      fallbackConfKey: String,
-      fallbackConfValue: String): String =
+      fallbackConfKey: Option[String],
+      fallbackConfValue: Option[String]): String =
     (from, to) match {
       case (_: NumericType, TimestampType) =>
         suggestionOnConversionFunctions(from, to,
@@ -1957,10 +1960,10 @@ object AnsiCast {
         suggestionOnConversionFunctions(from, to, "function UNIX_DATE")
 
       // scalastyle:off line.size.limit
-      case _ if Cast.canCast(from, to) =>
+      case _ if fallbackConfKey.isDefined && fallbackConfValue.isDefined && 
Cast.canCast(from, to) =>
         s"""
            | cannot cast ${from.catalogString} to ${to.catalogString} with 
ANSI mode on.
-           | If you have to cast ${from.catalogString} to ${to.catalogString}, 
you can set $fallbackConfKey as $fallbackConfValue.
+           | If you have to cast ${from.catalogString} to ${to.catalogString}, 
you can set ${fallbackConfKey.get} as ${fallbackConfValue.get}.
            |""".stripMargin
       // scalastyle:on line.size.limit
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
new file mode 100644
index 0000000..aba76db
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryCast.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.DataType
+
+/**
+ * A special version of [[AnsiCast]]. It performs the same operation (i.e. 
converts a value of
+ * one data type into another data type), but returns a NULL value instead of 
raising an error
+ * when the conversion can not be performed.
+ *
+ * When cast from/to timezone related types, we need timeZoneId, which will be 
resolved with
+ * session local timezone by an analyzer [[ResolveTimeZone]].
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data 
type `type`. " +
+    "This expression is identical to CAST with configuration 
`spark.sql.ansi.enabled` as " +
+    "true, except it returns NULL instead of raising an error. Note that the 
behavior of this " +
+    "expression doesn't depend on configuration `spark.sql.ansi.enabled`.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('10' as int);
+       10
+      > SELECT _FUNC_(1234567890123L as int);
+       null
+  """,
+  since = "3.2.0",
+  group = "conversion_funcs")
+case class TryCast(child: Expression, dataType: DataType, timeZoneId: 
Option[String] = None)
+  extends CastBase {
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Option(timeZoneId))
+
+  // Here we force `ansiEnabled` as true so that we can reuse the evaluation 
code branches which
+  // throw exceptions on conversion failures.
+  override protected val ansiEnabled: Boolean = true
+
+  override def nullable: Boolean = true
+
+  override def canCast(from: DataType, to: DataType): Boolean = 
AnsiCast.canCast(from, to)
+
+  override def cast(from: DataType, to: DataType): Any => Any = (input: Any) =>
+    try {
+      super.cast(from, to)(input)
+    } catch {
+      case _: Exception =>
+        null
+    }
+
+  override def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: 
ExprValue,
+    result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: 
CastFunction): Block = {
+    val javaType = JavaCode.javaType(resultType)
+    code"""
+      boolean $resultIsNull = $inputIsNull;
+      $javaType $result = ${CodeGenerator.defaultValue(resultType)};
+      if (!$inputIsNull) {
+        try {
+          ${cast(input, result, resultIsNull)}
+        } catch (Exception e) {
+          $resultIsNull = true;
+        }
+      }
+    """
+  }
+
+  override def typeCheckFailureMessage: String =
+    AnsiCast.typeCheckFailureMessage(child.dataType, dataType, None, None)
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c49fbab..dc87398 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1654,7 +1654,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with 
SQLConfHelper with Logg
   override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
     val rawDataType = typedVisit[DataType](ctx.dataType())
     val dataType = 
CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
-    val cast = Cast(expression(ctx.expression), dataType)
+    val cast = ctx.name.getType match {
+      case SqlBaseParser.CAST =>
+        Cast(expression(ctx.expression), dataType)
+
+      case SqlBaseParser.TRY_CAST =>
+        TryCast(expression(ctx.expression), dataType)
+    }
     cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
     cast
   }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index e095910..3a79e8d 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -52,6 +52,8 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
     checkEvaluation(cast(Literal.create(null, from), to, UTC_OPT), null)
   }
 
+  protected def isAlwaysNullable: Boolean = false
+
   test("null cast") {
     import DataTypeTestUtils._
 
@@ -252,8 +254,8 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
   }
 
   test("cast from string") {
-    assert(cast("abcdef", StringType).nullable === false)
-    assert(cast("abcdef", BinaryType).nullable === false)
+    assert(cast("abcdef", StringType).nullable === isAlwaysNullable)
+    assert(cast("abcdef", BinaryType).nullable === isAlwaysNullable)
     assert(cast("abcdef", BooleanType).nullable)
     assert(cast("abcdef", TimestampType).nullable)
     assert(cast("abcdef", LongType).nullable)
@@ -910,9 +912,11 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
 
     if (optionalExpectedMsg.isDefined) {
       assert(message.contains(optionalExpectedMsg.get))
-    } else {
+    } else if (setConfigurationHint.nonEmpty) {
       assert(message.contains("with ANSI mode on"))
       assert(message.contains(setConfigurationHint))
+    } else {
+      assert("cannot cast [a-zA-Z]+ to 
[a-zA-Z]+".r.findFirstIn(message).isDefined)
     }
   }
 
@@ -965,11 +969,6 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
   test("cast from invalid string to numeric should throw 
NumberFormatException") {
     // cast to IntegerType
     Seq(IntegerType, ShortType, ByteType, LongType).foreach { dataType =>
-      val array = Literal.create(Seq("123", "true", "f", null),
-        ArrayType(StringType, containsNull = true))
-      checkExceptionInExpression[NumberFormatException](
-        cast(array, ArrayType(dataType, containsNull = true)),
-        "invalid input syntax for type numeric: true")
       checkExceptionInExpression[NumberFormatException](
         cast("string", dataType), "invalid input syntax for type numeric: 
string")
       checkExceptionInExpression[NumberFormatException](
@@ -990,6 +989,25 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
     }
   }
 
+  protected def checkCastToNumericError(l: Literal, to: DataType, 
tryCastResult: Any): Unit = {
+    checkExceptionInExpression[NumberFormatException](
+      cast(l, to), "invalid input syntax for type numeric: true")
+  }
+
+  test("cast from invalid string array to numeric array should throw 
NumberFormatException") {
+    val array = Literal.create(Seq("123", "true", "f", null),
+      ArrayType(StringType, containsNull = true))
+
+    checkCastToNumericError(array, ArrayType(ByteType, containsNull = true),
+      Seq(123.toByte, null, null, null))
+    checkCastToNumericError(array, ArrayType(ShortType, containsNull = true),
+      Seq(123.toShort, null, null, null))
+    checkCastToNumericError(array, ArrayType(IntegerType, containsNull = true),
+      Seq(123, null, null, null))
+    checkCastToNumericError(array, ArrayType(LongType, containsNull = true),
+      Seq(123L, null, null, null))
+  }
+
   test("Fast fail for cast string type to decimal type in ansi mode") {
     checkEvaluation(cast("12345678901234567890123456789012345678", 
DecimalType(38, 0)),
       Decimal("12345678901234567890123456789012345678"))
@@ -1023,14 +1041,14 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
       "invalid input syntax for type numeric")
   }
 
-  protected def checkCastToBooleanError(l: Literal, to: DataType): Unit = {
+  protected def checkCastToBooleanError(l: Literal, to: DataType, 
tryCastResult: Any): Unit = {
     checkExceptionInExpression[UnsupportedOperationException](
       cast(l, to), s"invalid input syntax for type boolean")
   }
 
   test("ANSI mode: cast string to boolean with parse error") {
-    checkCastToBooleanError(Literal("abc"), BooleanType)
-    checkCastToBooleanError(Literal(""), BooleanType)
+    checkCastToBooleanError(Literal("abc"), BooleanType, null)
+    checkCastToBooleanError(Literal(""), BooleanType, null)
   }
 
   test("cast from array II") {
@@ -1043,14 +1061,14 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
       val to: DataType = ArrayType(BooleanType, containsNull = true)
       val ret = cast(array, to)
       assert(ret.resolved)
-      checkCastToBooleanError(array, to)
+      checkCastToBooleanError(array, to, Seq(null, true, false, null))
     }
 
     {
       val to: DataType = ArrayType(BooleanType, containsNull = true)
       val ret = cast(array_notNull, to)
       assert(ret.resolved)
-      checkCastToBooleanError(array_notNull, to)
+      checkCastToBooleanError(array_notNull, to, Seq(null, true, false))
     }
   }
 
@@ -1068,14 +1086,14 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
       val to: DataType = MapType(StringType, BooleanType, valueContainsNull = 
true)
       val ret = cast(map, to)
       assert(ret.resolved)
-      checkCastToBooleanError(map, to)
+      checkCastToBooleanError(map, to, Map("a" -> null, "b" -> true, "c" -> 
false, "d" -> null))
     }
 
     {
       val to: DataType = MapType(StringType, BooleanType, valueContainsNull = 
true)
       val ret = cast(map_notNull, to)
       assert(ret.resolved)
-      checkCastToBooleanError(map_notNull, to)
+      checkCastToBooleanError(map_notNull, to, Map("a" -> null, "b" -> true, 
"c" -> false))
     }
   }
 
@@ -1117,7 +1135,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
         StructField("d", BooleanType, nullable = true)))
       val ret = cast(struct, to)
       assert(ret.resolved)
-      checkCastToBooleanError(struct, to)
+      checkCastToBooleanError(struct, to, InternalRow(null, true, false, null))
     }
 
     {
@@ -1127,7 +1145,7 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
         StructField("c", BooleanType, nullable = true)))
       val ret = cast(struct_notNull, to)
       assert(ret.resolved)
-      checkCastToBooleanError(struct_notNull, to)
+      checkCastToBooleanError(struct_notNull, to, InternalRow(null, true, 
false))
     }
   }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
new file mode 100644
index 0000000..bcf8a22
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryCastSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.DataType
+
+class TryCastSuite extends AnsiCastSuiteBase {
+  override protected def cast(v: Any, targetType: DataType, timeZoneId: 
Option[String]) = {
+    v match {
+      case lit: Expression => TryCast(lit, targetType, timeZoneId)
+      case _ => TryCast(Literal(v), targetType, timeZoneId)
+    }
+  }
+
+  override def isAlwaysNullable: Boolean = true
+
+  override protected def setConfigurationHint: String = ""
+
+  override def checkExceptionInExpression[T <: Throwable : ClassTag](
+      expression: => Expression,
+      inputRow: InternalRow,
+      expectedErrMsg: String): Unit = {
+    checkEvaluation(expression, null, inputRow)
+  }
+
+  override def checkCastToBooleanError(l: Literal, to: DataType, 
tryCastResult: Any): Unit = {
+    checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+  }
+
+  override def checkCastToNumericError(l: Literal, to: DataType, 
tryCastResult: Any): Unit = {
+    checkEvaluation(cast(l, to), tryCastResult, InternalRow(l.value))
+  }
+}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_cast.sql 
b/sql/core/src/test/resources/sql-tests/inputs/try_cast.sql
new file mode 100644
index 0000000..2d58484
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/try_cast.sql
@@ -0,0 +1,54 @@
+-- TRY_CAST string representing a valid fractional number to integral should 
return null
+SELECT TRY_CAST('1.23' AS int);
+SELECT TRY_CAST('1.23' AS long);
+SELECT TRY_CAST('-4.56' AS int);
+SELECT TRY_CAST('-4.56' AS long);
+
+-- TRY_CAST string which are not numbers to integral should return null
+SELECT TRY_CAST('abc' AS int);
+SELECT TRY_CAST('abc' AS long);
+
+-- TRY_CAST empty string to integral should return null
+SELECT TRY_CAST('' AS int);
+SELECT TRY_CAST('' AS long);
+
+-- TRY_CAST null to integral should return null
+SELECT TRY_CAST(NULL AS int);
+SELECT TRY_CAST(NULL AS long);
+
+-- TRY_CAST invalid decimal string to integral should return null
+SELECT TRY_CAST('123.a' AS int);
+SELECT TRY_CAST('123.a' AS long);
+
+-- '-2147483648' is the smallest int value
+SELECT TRY_CAST('-2147483648' AS int);
+SELECT TRY_CAST('-2147483649' AS int);
+
+-- '2147483647' is the largest int value
+SELECT TRY_CAST('2147483647' AS int);
+SELECT TRY_CAST('2147483648' AS int);
+
+-- '-9223372036854775808' is the smallest long value
+SELECT TRY_CAST('-9223372036854775808' AS long);
+SELECT TRY_CAST('-9223372036854775809' AS long);
+
+-- '9223372036854775807' is the largest long value
+SELECT TRY_CAST('9223372036854775807' AS long);
+SELECT TRY_CAST('9223372036854775808' AS long);
+
+-- TRY_CAST string to interval and interval to string
+SELECT TRY_CAST('interval 3 month 1 hour' AS interval);
+SELECT TRY_CAST('abc' AS interval);
+
+-- TRY_CAST string to boolean
+select TRY_CAST('true' as boolean);
+select TRY_CAST('false' as boolean);
+select TRY_CAST('abc' as boolean);
+
+-- TRY_CAST string to date
+SELECT TRY_CAST("2021-01-01" AS date);
+SELECT TRY_CAST("2021-101-01" AS date);
+
+-- TRY_CAST string to timestamp
+SELECT TRY_CAST("2021-01-01 00:00:00" AS timestamp);
+SELECT TRY_CAST("2021-101-01 00:00:00" AS timestamp);
\ No newline at end of file
diff --git a/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out 
b/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out
new file mode 100644
index 0000000..810b82f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/try_cast.sql.out
@@ -0,0 +1,234 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 29
+
+
+-- !query
+SELECT TRY_CAST('1.23' AS int)
+-- !query schema
+struct<CAST(1.23 AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('1.23' AS long)
+-- !query schema
+struct<CAST(1.23 AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('-4.56' AS int)
+-- !query schema
+struct<CAST(-4.56 AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('-4.56' AS long)
+-- !query schema
+struct<CAST(-4.56 AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('abc' AS int)
+-- !query schema
+struct<CAST(abc AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('abc' AS long)
+-- !query schema
+struct<CAST(abc AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('' AS int)
+-- !query schema
+struct<CAST( AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('' AS long)
+-- !query schema
+struct<CAST( AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST(NULL AS int)
+-- !query schema
+struct<CAST(NULL AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST(NULL AS long)
+-- !query schema
+struct<CAST(NULL AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('123.a' AS int)
+-- !query schema
+struct<CAST(123.a AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('123.a' AS long)
+-- !query schema
+struct<CAST(123.a AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('-2147483648' AS int)
+-- !query schema
+struct<CAST(-2147483648 AS INT):int>
+-- !query output
+-2147483648
+
+
+-- !query
+SELECT TRY_CAST('-2147483649' AS int)
+-- !query schema
+struct<CAST(-2147483649 AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('2147483647' AS int)
+-- !query schema
+struct<CAST(2147483647 AS INT):int>
+-- !query output
+2147483647
+
+
+-- !query
+SELECT TRY_CAST('2147483648' AS int)
+-- !query schema
+struct<CAST(2147483648 AS INT):int>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('-9223372036854775808' AS long)
+-- !query schema
+struct<CAST(-9223372036854775808 AS BIGINT):bigint>
+-- !query output
+-9223372036854775808
+
+
+-- !query
+SELECT TRY_CAST('-9223372036854775809' AS long)
+-- !query schema
+struct<CAST(-9223372036854775809 AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('9223372036854775807' AS long)
+-- !query schema
+struct<CAST(9223372036854775807 AS BIGINT):bigint>
+-- !query output
+9223372036854775807
+
+
+-- !query
+SELECT TRY_CAST('9223372036854775808' AS long)
+-- !query schema
+struct<CAST(9223372036854775808 AS BIGINT):bigint>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST('interval 3 month 1 hour' AS interval)
+-- !query schema
+struct<CAST(interval 3 month 1 hour AS INTERVAL):interval>
+-- !query output
+3 months 1 hours
+
+
+-- !query
+SELECT TRY_CAST('abc' AS interval)
+-- !query schema
+struct<CAST(abc AS INTERVAL):interval>
+-- !query output
+NULL
+
+
+-- !query
+select TRY_CAST('true' as boolean)
+-- !query schema
+struct<CAST(true AS BOOLEAN):boolean>
+-- !query output
+true
+
+
+-- !query
+select TRY_CAST('false' as boolean)
+-- !query schema
+struct<CAST(false AS BOOLEAN):boolean>
+-- !query output
+false
+
+
+-- !query
+select TRY_CAST('abc' as boolean)
+-- !query schema
+struct<CAST(abc AS BOOLEAN):boolean>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST("2021-01-01" AS date)
+-- !query schema
+struct<CAST(2021-01-01 AS DATE):date>
+-- !query output
+2021-01-01
+
+
+-- !query
+SELECT TRY_CAST("2021-101-01" AS date)
+-- !query schema
+struct<CAST(2021-101-01 AS DATE):date>
+-- !query output
+NULL
+
+
+-- !query
+SELECT TRY_CAST("2021-01-01 00:00:00" AS timestamp)
+-- !query schema
+struct<CAST(2021-01-01 00:00:00 AS TIMESTAMP):timestamp>
+-- !query output
+2021-01-01 00:00:00
+
+
+-- !query
+SELECT TRY_CAST("2021-101-01 00:00:00" AS timestamp)
+-- !query schema
+struct<CAST(2021-101-01 00:00:00 AS TIMESTAMP):timestamp>
+-- !query output
+NULL

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to