This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new d0a7c6f [SPARK-31750][SQL] Eliminate UpCast if child's dataType is
DecimalType
d0a7c6f is described below
commit d0a7c6fdc1b90aa372c004b66258865ab37c6c3d
Author: yi.wu <[email protected]>
AuthorDate: Wed May 20 11:00:58 2020 +0900
[SPARK-31750][SQL] Eliminate UpCast if child's dataType is DecimalType
### What changes were proposed in this pull request?
Eliminate the `UpCast` if its child's data type is already a decimal type.
### Why are the changes needed?
While deserializing an internal `Decimal` value to an external
`BigDecimal` (Java/Scala) value, Spark should also respect the `Decimal`'s
precision and scale; otherwise the up-cast can lose precision and is
rejected with a confusing analysis error in some cases,
e.g.:
```
sql("select cast(11111111111111111111111111111111111111 as decimal(38, 0))
as d")
.write.mode("overwrite")
.parquet(f.getAbsolutePath)
// can fail
spark.read.parquet(f.getAbsolutePath).as[BigDecimal]
```
```
[info] org.apache.spark.sql.AnalysisException: Cannot up cast `d` from
decimal(38,0) to decimal(38,18).
[info] The type path of the target object is:
[info] - root class: "scala.math.BigDecimal"
[info] You can either add an explicit cast to the input data or choose a
higher precision type of the field in the target object;
[info] at
org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveUpCast$$fail(Analyzer.scala:3060)
[info] at
org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$33$$anonfun$applyOrElse$174.applyOrElse(Analyzer.scala:3087)
[info] at
org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$33$$anonfun$applyOrElse$174.applyOrElse(Analyzer.scala:3071)
[info] at
org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:309)
[info] at
org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
[info] at
org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:309)
[info] at
org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$3(TreeNode.scala:314)
```
### Does this PR introduce _any_ user-facing change?
Yes. The cases mentioned above (which would lose precision) fail before this
change but run successfully after this change.
### How was this patch tested?
Added tests.
Closes #28572 from Ngone51/fix_encoder.
Authored-by: yi.wu <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
---
.../sql/catalyst/DeserializerBuildHelper.scala | 4 ++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 24 +++++++++++++++++-----
.../spark/sql/catalyst/expressions/Cast.scala | 10 ++++++++-
.../catalyst/encoders/EncoderResolutionSuite.scala | 11 ++++++++--
.../org/apache/spark/sql/DataFrameSuite.scala | 11 ++++++++++
5 files changed, 52 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index e55c25c..701e4e3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -161,6 +161,10 @@ object DeserializerBuildHelper {
case _: StructType => expr
case _: ArrayType => expr
case _: MapType => expr
+ case _: DecimalType =>
+ // For Scala/Java `BigDecimal`, we accept decimal types of any valid
precision/scale.
+ // Here we use the `DecimalType` object to indicate it.
+ UpCast(expr, DecimalType, walkedTypePath.getPaths)
case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0d71aee..654cf42 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3142,15 +3142,29 @@ class Analyzer(
case p => p transformExpressions {
case u @ UpCast(child, _, _) if !child.resolved => u
- case UpCast(child, dt: AtomicType, _)
+ case UpCast(_, target, _) if target != DecimalType &&
!target.isInstanceOf[DataType] =>
+ throw new AnalysisException(
+ s"UpCast only support DecimalType as AbstractDataType yet, but
got: $target")
+
+ case UpCast(child, target, walkedTypePath) if target == DecimalType
+ && child.dataType.isInstanceOf[DecimalType] =>
+ assert(walkedTypePath.nonEmpty,
+ "object DecimalType should only be used inside ExpressionEncoder")
+
+ // SPARK-31750: if we want to upcast to the general decimal type,
and the `child` is
+ // already decimal type, we can remove the `Upcast` and accept any
precision/scale.
+ // This can happen for cases like
`spark.read.parquet("/tmp/file").as[BigDecimal]`.
+ child
+
+ case UpCast(child, target: AtomicType, _)
if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
child.dataType == StringType =>
- Cast(child, dt.asNullable)
+ Cast(child, target.asNullable)
- case UpCast(child, dataType, walkedTypePath) if
!Cast.canUpCast(child.dataType, dataType) =>
- fail(child, dataType, walkedTypePath)
+ case u @ UpCast(child, _, walkedTypePath) if
!Cast.canUpCast(child.dataType, u.dataType) =>
+ fail(child, u.dataType, walkedTypePath)
- case UpCast(child, dataType, _) => Cast(child, dataType.asNullable)
+ case u @ UpCast(child, _, _) => Cast(child, u.dataType.asNullable)
}
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 9bc86d4..ef70915 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -1742,8 +1742,16 @@ case class AnsiCast(child: Expression, dataType:
DataType, timeZoneId: Option[St
/**
* Cast the child expression to the target data type, but will throw error if
the cast might
* truncate, e.g. long -> int, timestamp -> data.
+ *
+ * Note: `target` is `AbstractDataType`, so that we can put `object
DecimalType`, which means
+ * we accept `DecimalType` with any valid precision/scale.
*/
-case class UpCast(child: Expression, dataType: DataType, walkedTypePath:
Seq[String] = Nil)
+case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath:
Seq[String] = Nil)
extends UnaryExpression with Unevaluable {
override lazy val resolved = false
+
+ def dataType: DataType = target match {
+ case DecimalType => DecimalType.SYSTEM_DEFAULT
+ case _ => target.asInstanceOf[DataType]
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 48f4ef5..577814b 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -22,9 +22,9 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute,
AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -247,6 +247,13 @@ class EncoderResolutionSuite extends PlanTest {
""".stripMargin.trim + " of the field in the target object")
}
+ test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
+ val encoder = ExpressionEncoder[Seq[BigDecimal]]
+ val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))())
+ // Before SPARK-31750, it will fail because Decimal(38, 0) can not be
casted to Decimal(38, 18)
+ testFromRow(encoder, attr,
InternalRow(ArrayData.toArrayData(Array(Decimal(1.0)))))
+ }
+
// test for leaf types
castSuccess[Int, Long]
castSuccess[java.sql.Date, java.sql.Timestamp]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4e91a7c..954a4bd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2439,6 +2439,17 @@ class DataFrameSuite extends QueryTest
val nestedDecArray = Array(decSpark)
checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
}
+
+ test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") {
+ withTempPath { f =>
+ sql("select cast(1 as decimal(38, 0)) as d")
+ .write.mode("overwrite")
+ .parquet(f.getAbsolutePath)
+
+ val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal]
+ assert(df.schema === new StructType().add(StructField("d",
DecimalType(38, 0))))
+ }
+ }
}
case class GroupByKey(a: Int, b: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]