This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5799d18f99f [SPARK-41991][SQL] `CheckOverflowInTableInsert` should
accept ExpressionProxy as child
5799d18f99f is described below
commit 5799d18f99fb44e254c693958e868ce0c84bcb42
Author: Bruce Robbins <[email protected]>
AuthorDate: Fri Jan 13 09:41:12 2023 +0900
[SPARK-41991][SQL] `CheckOverflowInTableInsert` should accept
ExpressionProxy as child
### What changes were proposed in this pull request?
Change `CheckOverflowInTableInsert` to accept a `Cast` wrapped by an
`ExpressionProxy` as a child.
### Why are the changes needed?
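The following insert statement fails in interpreted mode (that is, with
whole-stage codegen disabled and the codegen factory mode set to `NO_CODEGEN`):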
```
drop table if exists tbl1;
create table tbl1 (a int, b int) using parquet;
set spark.sql.codegen.wholeStage=false;
set spark.sql.codegen.factoryMode=NO_CODEGEN;
insert into tbl1
select id as a, id as b
from range(1, 5);
```
It gets the following exception:
```
java.lang.ClassCastException:
org.apache.spark.sql.catalyst.expressions.ExpressionProxy cannot be cast to
org.apache.spark.sql.catalyst.expressions.Cast
at
org.apache.spark.sql.catalyst.expressions.CheckOverflowInTableInsert.withNewChildInternal(Cast.scala:2514)
at
org.apache.spark.sql.catalyst.expressions.CheckOverflowInTableInsert.withNewChildInternal(Cast.scala:2512)
```
The query produces two `bigint` columns, but the table's schema expects two
`int` columns, so Spark wraps each output field with a `Cast`.
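Later, in `InterpretedUnsafeProjection`, `prepareExpressions` tries to wrap
the two `Cast` expressions with an `ExpressionProxy`. However, the parent
expression of each `Cast` is a `CheckOverflowInTableInsert` expression, whose
`withNewChildInternal` unconditionally casts the new child to `Cast`. It
therefore does not accept an `ExpressionProxy` as a child and fails with the
`ClassCastException` shown above.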
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests.
Closes #39518 from bersprockets/subexpr_elim_issue.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../spark/sql/catalyst/expressions/Cast.scala | 33 ++++++++++++++++++----
.../sql/catalyst/expressions/CastSuiteBase.scala | 28 +++++++++++++++++-
.../SubExprEvaluationRuntimeSuite.scala | 12 ++++++++
.../sql/errors/QueryExecutionAnsiErrorsSuite.scala | 18 +++++++++++-
4 files changed, 83 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d7a3952e634..b72ba3ea8a0 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,7 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
import java.util.Locale
import java.util.concurrent.TimeUnit._
-import org.apache.spark.SparkArithmeticException
+import org.apache.spark.{SparkArithmeticException, SparkException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -2509,21 +2509,42 @@ case class UpCast(child: Expression, target:
AbstractDataType, walkedTypePath: S
* Casting a numeric value as another numeric type in store assignment. It can
capture the
* arithmetic errors and show proper error messages to users.
*/
-case class CheckOverflowInTableInsert(child: Cast, columnName: String) extends
UnaryExpression {
- override protected def withNewChildInternal(newChild: Expression):
Expression =
- copy(child = newChild.asInstanceOf[Cast])
+case class CheckOverflowInTableInsert(child: Expression, columnName: String)
+ extends UnaryExpression {
+ checkChild(child)
+
+ private def checkChild(child: Expression): Unit = child match {
+ case _: Cast =>
+ case ExpressionProxy(c, _, _) if c.isInstanceOf[Cast] =>
+ case _ =>
+ throw SparkException.internalError("Child is not Cast or ExpressionProxy
of Cast")
+ }
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression = {
+ checkChild(newChild)
+ copy(child = newChild)
+ }
+
+ private def getCast: Cast = child match {
+ case c: Cast =>
+ c
+ case ExpressionProxy(c, _, _) =>
+ c.asInstanceOf[Cast]
+ }
override def eval(input: InternalRow): Any = try {
child.eval(input)
} catch {
case e: SparkArithmeticException =>
+ val cast = getCast
throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
- child.child.dataType,
- child.dataType,
+ cast.child.dataType,
+ cast.dataType,
columnName)
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ val child = getCast
val childGen = child.genCode(ctx)
val exceptionClass = classOf[SparkArithmeticException].getCanonicalName
val fromDt =
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index bad85ca4176..ca9f43adc1f 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -24,7 +24,7 @@ import java.util.{Calendar, Locale, TimeZone}
import scala.collection.parallel.immutable.ParVector
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -1391,4 +1391,30 @@ abstract class CastSuiteBase extends SparkFunSuite with
ExpressionEvalHelper {
assert(expr.sql == cast.sql)
assert(expr.toString == cast.toString)
}
+
+ test("SPARK-41991: CheckOverflowInTableInsert child must be Cast or
ExpressionProxy of Cast") {
+ val runtime = new SubExprEvaluationRuntime(1)
+ val cast = Cast(Literal(1.0), IntegerType)
+ val expr = CheckOverflowInTableInsert(cast, "column_1")
+ val proxy = ExpressionProxy(Literal(1.0), 0, runtime)
+ checkError(
+ exception = intercept[SparkException] {
+ expr.withNewChildrenInternal(IndexedSeq(proxy))
+ },
+ errorClass = "INTERNAL_ERROR",
+ parameters = Map(
+ "message" -> "Child is not Cast or ExpressionProxy of Cast"
+ )
+ )
+
+ checkError(
+ exception = intercept[SparkException] {
+ expr.withNewChildrenInternal(IndexedSeq(Literal(1)))
+ },
+ errorClass = "INTERNAL_ERROR",
+ parameters = Map(
+ "message" -> "Child is not Cast or ExpressionProxy of Cast"
+ )
+ )
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
index f8dca266a62..db3ebec82de 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntimeSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.LongType
class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
@@ -117,4 +118,15 @@ class SubExprEvaluationRuntimeSuite extends SparkFunSuite {
assert(proxys.size == 2)
assert(proxys.forall(_.child.semanticEquals(mul2_1)))
}
+
+ test("SPARK-41991: CheckOverflowInTableInsert with ExpressionProxy child") {
+ val runtime = new SubExprEvaluationRuntime(1)
+ val proxy = ExpressionProxy(Cast(Literal.apply(1), LongType), 0, runtime)
+ val checkOverflow = CheckOverflowInTableInsert(Cast(Literal.apply(1),
LongType), "col")
+ .withNewChildrenInternal(IndexedSeq(proxy))
+ assert(runtime.cache.size() == 0)
+ checkOverflow.eval()
+ assert(runtime.cache.size() == 1)
+ assert(runtime.cache.get(proxy) == ResultProxy(1L))
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index 1ec852d662f..dc60fa8b025 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.errors
import org.apache.spark._
import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.expressions.{Cast,
CheckOverflowInTableInsert, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast,
CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.ByteType
@@ -192,4 +192,20 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest
"columnName" -> "`col`")
)
}
+
+ test("SPARK-41991: interpreted CheckOverflowInTableInsert with
ExpressionProxy should " +
+ "throw an exception") {
+ val runtime = new SubExprEvaluationRuntime(1)
+ val proxy = ExpressionProxy(Cast(Literal.apply(12345678901234567890D),
ByteType), 0, runtime)
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ CheckOverflowInTableInsert(proxy, "col").eval(null)
+ }.asInstanceOf[SparkThrowable],
+ errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT",
+ parameters = Map(
+ "sourceType" -> "\"DOUBLE\"",
+ "targetType" -> ("\"TINYINT\""),
+ "columnName" -> "`col`")
+ )
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]