This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 98c0ca7 [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF
98c0ca7 is described below
commit 98c0ca78610ccf62784081353584717c62285485
Author: Marco Gaido <[email protected]>
AuthorDate: Thu Dec 20 14:17:44 2018 +0800
[SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF
## What changes were proposed in this pull request?
Currently, when we infer the schema for scala/java decimals, we return as
data type the `SYSTEM_DEFAULT` implementation, ie. the decimal type with
precision 38 and scale 18. But this is not right, as we know nothing about the
actual precision and scale, and these values may not be enough to store the data.
This problem arises in particular with UDF, where we cast all the input of type
`DecimalType` to a `DecimalType(38, 18)`: in case this is not enough, null is
returned as input for the UDF.
The PR defines a custom handling for casting to the expected data types for
ScalaUDF: the decimal precision and scale is picked from the input, so no
casting to a different and maybe wrong precision and scale happens.
## How was this patch tested?
added UTs
Closes #23308 from mgaido91/SPARK-26308.
Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 31 +++++++++++++++++++++
.../spark/sql/catalyst/expressions/ScalaUDF.scala | 2 +-
.../test/scala/org/apache/spark/sql/UDFSuite.scala | 32 +++++++++++++++++++++-
3 files changed, 63 insertions(+), 2 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 133fa11..1706b3e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -879,6 +879,37 @@ object TypeCoercion {
}
}
e.withNewChildren(children)
+
+ case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
+ val children = udf.children.zip(udf.inputTypes).map { case (in,
expected) =>
+ implicitCast(in, udfInputToCastType(in.dataType,
expected)).getOrElse(in)
+ }
+ udf.withNewChildren(children)
+ }
+
+ private def udfInputToCastType(input: DataType, expectedType: DataType):
DataType = {
+ (input, expectedType) match {
+ // SPARK-26308: avoid casting to an arbitrary precision and scale for
decimals. Please note
+ // that precision and scale cannot be inferred properly for a ScalaUDF
because, when it is
+ // created, it is not bound to any column. So here the precision and
scale of the input
+ // column is used.
+ case (in: DecimalType, _: DecimalType) => in
+ case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
+ ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp)
+ case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp,
nullableExp)) =>
+ MapType(udfInputToCastType(keyDtIn, keyDtExp),
+ udfInputToCastType(valueDtIn, valueDtExp),
+ nullableExp)
+ case (StructType(fieldsIn), StructType(fieldsExp)) =>
+ val fieldTypes =
+ fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case
(dtIn, dtExp) =>
+ udfInputToCastType(dtIn, dtExp)
+ }
+ StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) =>
+ field.copy(dataType = newDt)
+ })
+ case (_, other) => other
+ }
}
/**
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index fae90ca..a23aaa3 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -52,7 +52,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
- extends Expression with ImplicitCastInputTypes with NonSQLExpression with
UserDefinedExpression {
+ extends Expression with NonSQLExpression with UserDefinedExpression {
// The constructor for SPARK 2.1 and 2.2
def this(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 20dcefa..a26d306 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.math.BigDecimal
+
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.QueryExecution
@@ -26,7 +28,7 @@ import
org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm
import org.apache.spark.sql.functions.{lit, udf}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.QueryExecutionListener
@@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
}
}
+
+ test("SPARK-26308: udf with decimal") {
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(new
BigDecimal("2011000000000002456556")))),
+ StructType(Seq(StructField("col1", DecimalType(30, 0)))))
+ val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => {
+ if (value == null) null else value.toBigInteger.toString
+ })
+ checkAnswer(df1.select(udf1(df1.col("col1"))),
Seq(Row("2011000000000002456556")))
+ }
+
+ test("SPARK-26308: udf with complex types of decimal") {
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(Array(new
BigDecimal("2011000000000002456556"))))),
+ StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0))))))
+ val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => {
+ arr.map(value => if (value == null) null else
value.toBigInteger.toString)
+ })
+ checkAnswer(df1.select(udf1($"col1")),
Seq(Row(Array("2011000000000002456556"))))
+
+ val df2 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(Map("a" -> new
BigDecimal("2011000000000002456556"))))),
+ StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30,
0))))))
+ val udf2 = org.apache.spark.sql.functions.udf((map: Map[String,
BigDecimal]) => {
+ map.mapValues(value => if (value == null) null else
value.toBigInteger.toString)
+ })
+ checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" ->
"2011000000000002456556"))))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]