This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5799d18f99f [SPARK-41991][SQL] `CheckOverflowInTableInsert` should accept ExpressionProxy as child
5799d18f99f is described below

commit 5799d18f99fb44e254c693958e868ce0c84bcb42
Author: Bruce Robbins <[email protected]>
AuthorDate: Fri Jan 13 09:41:12 2023 +0900

    [SPARK-41991][SQL] `CheckOverflowInTableInsert` should accept ExpressionProxy as child
    
    ### What changes were proposed in this pull request?
    
    Change `CheckOverflowInTableInsert` to accept a `Cast` wrapped by an `ExpressionProxy` as a child.
    
    ### Why are the changes needed?
    
    The following insert statement fails in interpreted mode (whole-stage codegen disabled and expression codegen set to `NO_CODEGEN`):
    ```
    drop table if exists tbl1;
    create table tbl1 (a int, b int) using parquet;
    
    set spark.sql.codegen.wholeStage=false;
    set spark.sql.codegen.factoryMode=NO_CODEGEN;
    
    insert into tbl1
    select id as a, id as b
    from range(1, 5);
    ```
    It gets the following exception:
    ```
    java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.ExpressionProxy cannot be cast to org.apache.spark.sql.catalyst.expressions.Cast
            at org.apache.spark.sql.catalyst.expressions.CheckOverflowInTableInsert.withNewChildInternal(Cast.scala:2514)
            at org.apache.spark.sql.catalyst.expressions.CheckOverflowInTableInsert.withNewChildInternal(Cast.scala:2512)
    ```
    The query produces two bigint columns, but the table's schema expects two int columns, so Spark wraps each output field in a `Cast`.
    
    Later, in `InterpretedUnsafeProjection`, `prepareExpressions` tries to wrap each of the two `Cast` expressions in an `ExpressionProxy` (for subexpression elimination). However, the parent of each `Cast` is a `CheckOverflowInTableInsert` expression. Its `withNewChildInternal` casts the new child to `Cast`, so it fails when given an `ExpressionProxy`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New unit tests in `CastSuiteBase`, `SubExprEvaluationRuntimeSuite`, and `QueryExecutionAnsiErrorsSuite`.
    
    Closes #39518 from bersprockets/subexpr_elim_issue.
    
    Authored-by: Bruce Robbins <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../spark/sql/catalyst/expressions/Cast.scala      | 33 ++++++++++++++++++----
 .../sql/catalyst/expressions/CastSuiteBase.scala   | 28 +++++++++++++++++-
 .../SubExprEvaluationRuntimeSuite.scala            | 12 ++++++++
 .../sql/errors/QueryExecutionAnsiErrorsSuite.scala | 18 +++++++++++-
 4 files changed, 83 insertions(+), 8 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d7a3952e634..b72ba3ea8a0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
 
-import org.apache.spark.SparkArithmeticException
+import org.apache.spark.{SparkArithmeticException, SparkException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -2509,21 +2509,42 @@ case class UpCast(child: Expression, target: 
AbstractDataType, walkedTypePath: S
  * Casting a numeric value as another numeric type in store assignment. It can 
capture the
  * arithmetic errors and show proper error messages to users.
  */
-case class CheckOverflowInTableInsert(child: Cast, columnName: String) extends 
UnaryExpression {
-  override protected def withNewChildInternal(newChild: Expression): 
Expression =
-    copy(child = newChild.asInstanceOf[Cast])
+case class CheckOverflowInTableInsert(child: Expression, columnName: String)
+    extends UnaryExpression {
+  checkChild(child)
+
+  private def checkChild(child: Expression): Unit = child match {
+    case _: Cast =>
+    case ExpressionProxy(c, _, _) if c.isInstanceOf[Cast] =>
+    case _ =>
+      throw SparkException.internalError("Child is not Cast or ExpressionProxy 
of Cast")
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): 
Expression = {
+    checkChild(newChild)
+    copy(child = newChild)
+  }
+
+  private def getCast: Cast = child match {
+    case c: Cast =>
+      c
+    case ExpressionProxy(c, _, _) =>
+      c.asInstanceOf[Cast]
+  }
 
   override def eval(input: InternalRow): Any = try {
     child.eval(input)
   } catch {
     case e: SparkArithmeticException =>
+      val cast = getCast
       throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
-        child.child.dataType,
-        child.dataType,
+        cast.child.dataType,
+        cast.dataType,
         columnName)
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    val child = getCast
     val childGen = child.genCode(ctx)
     val exceptionClass = classOf[SparkArithmeticException].getCanonicalName
     val fromDt =
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index bad85ca4176..ca9f43adc1f 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -24,7 +24,7 @@ import java.util.{Calendar, Locale, TimeZone}
 
 import scala.collection.parallel.immutable.ParVector
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1391,4 +1391,30 @@ abstract class CastSuiteBase extends SparkFunSuite with 
ExpressionEvalHelper {
     assert(expr.sql == cast.sql)
     assert(expr.toString == cast.toString)
   }
+
+  test("SPARK-41991: CheckOverflowInTableInsert child must be Cast or 
ExpressionProxy of Cast") {
+    val runtime = new SubExprEvaluationRuntime(1)
+    val cast = Cast(Literal(1.0), IntegerType)
+    val expr = CheckOverflowInTableInsert(cast, "column_1")
+    val proxy = ExpressionProxy(Literal(1.0), 0, runtime)
+    checkError(
+      exception = intercept[SparkException] {
+        expr.withNewChildrenInternal(IndexedSeq(proxy))
+      },
+      errorClass = "INTERNAL_ERROR",
+      parameters = Map(
+        "message" -> "Child is not Cast or ExpressionProxy of Cast"
+      )
+    )
+
+    checkError(
+      exception = intercept[SparkException] {
+        expr.withNewChildrenInternal(IndexedSeq(Literal(1)))
+      },
+      errorClass = "INTERNAL_ERROR",
+      parameters = Map(
+        "message" -> "Child is not Cast or ExpressionProxy of Cast"
+      )
+    )
+  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
index f8dca266a62..db3ebec82de 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.LongType
 
 class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
 
@@ -117,4 +118,15 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
     assert(proxys.size == 2)
     assert(proxys.forall(_.child.semanticEquals(mul2_1)))
   }
+
+  test("SPARK-41991: CheckOverflowInTableInsert with ExpressionProxy child") {
+    val runtime = new SubExprEvaluationRuntime(1)
+    val proxy = ExpressionProxy(Cast(Literal.apply(1), LongType), 0, runtime)
+    val checkOverflow = CheckOverflowInTableInsert(Cast(Literal.apply(1), 
LongType), "col")
+      .withNewChildrenInternal(IndexedSeq(proxy))
+    assert(runtime.cache.size() == 0)
+    checkOverflow.eval()
+    assert(runtime.cache.size() == 1)
+    assert(runtime.cache.get(proxy) == ResultProxy(1L))
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index 1ec852d662f..dc60fa8b025 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors
 
 import org.apache.spark._
 import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.{Cast, 
CheckOverflowInTableInsert, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, 
CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.ByteType
@@ -192,4 +192,20 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
         "columnName" -> "`col`")
     )
   }
+
+  test("SPARK-41991: interpreted CheckOverflowInTableInsert with 
ExpressionProxy should " +
+    "throw an exception") {
+    val runtime = new SubExprEvaluationRuntime(1)
+    val proxy = ExpressionProxy(Cast(Literal.apply(12345678901234567890D), 
ByteType), 0, runtime)
+    checkError(
+      exception = intercept[SparkArithmeticException] {
+        CheckOverflowInTableInsert(proxy, "col").eval(null)
+      }.asInstanceOf[SparkThrowable],
+      errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT",
+      parameters = Map(
+        "sourceType" -> "\"DOUBLE\"",
+        "targetType" -> ("\"TINYINT\""),
+        "columnName" -> "`col`")
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to