This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 98c0ca7  [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF
98c0ca7 is described below

commit 98c0ca78610ccf62784081353584717c62285485
Author: Marco Gaido <marcogaid...@gmail.com>
AuthorDate: Thu Dec 20 14:17:44 2018 +0800

    [SPARK-26308][SQL] Avoid cast of decimals for ScalaUDF
    
    ## What changes were proposed in this pull request?
    
    Currently, when we infer the schema for Scala/Java decimals, we return the
    `SYSTEM_DEFAULT` data type, i.e. a decimal type with precision 38 and scale 18.
    This is not right, because we know nothing about the actual precision and scale,
    and these defaults may not be enough to store the data. The problem arises in
    particular with UDFs, where every input of type `DecimalType` is cast to
    `DecimalType(38, 18)`: when that is not enough, null is passed to the UDF as its
    input.
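
    A minimal sketch of the problem, adapted from the new test below (column name
    and value are only illustrative): the value has 22 integer digits, while
    `DecimalType(38, 18)` can hold at most 20, so the cast overflows and the UDF
    receives null.

    ```scala
    import java.math.BigDecimal
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
      StructType(Seq(StructField("col1", DecimalType(30, 0)))))
    // java.math.BigDecimal UDF arguments are inferred as DecimalType(38, 18)
    val f = udf((v: BigDecimal) => if (v == null) null else v.toBigInteger.toString)
    // Before this patch the column is cast to DecimalType(38, 18), overflows, and
    // the UDF sees null; with the patch the result is "2011000000000002456556".
    df.select(f(df("col1"))).show()
    ```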
    
    The PR adds custom handling for casting ScalaUDF inputs to their expected data
    types: the decimal precision and scale are taken from the input column, so no
    cast to a different, and possibly wrong, precision and scale happens.
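
    As a sketch of the intended behaviour (simplified: the real helper is a private
    method in `TypeCoercion` and also handles maps and structs), the type used for
    the implicit cast is derived from the input column rather than from the inferred
    UDF input type:

    ```scala
    import org.apache.spark.sql.types._

    // Simplified restatement of the new udfInputToCastType logic, for illustration only.
    def pickCastType(input: DataType, expected: DataType): DataType = (input, expected) match {
      // keep the input column's precision/scale instead of the inferred DecimalType(38, 18)
      case (in: DecimalType, _: DecimalType) => in
      // recurse into containers so nested decimals are preserved as well
      case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
        ArrayType(pickCastType(dtIn, dtExp), nullableExp)
      // non-decimal leaves keep the expected type
      case (_, other) => other
    }

    pickCastType(DecimalType(30, 0), DecimalType(38, 18))
    // => DecimalType(30,0)
    pickCastType(ArrayType(DecimalType(30, 0)), ArrayType(DecimalType(38, 18)))
    // => ArrayType(DecimalType(30,0), true)
    ```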
    
    ## How was this patch tested?
    
    added UTs
    
    Closes #23308 from mgaido91/SPARK-26308.
    
    Authored-by: Marco Gaido <marcogaid...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 31 +++++++++++++++++++++
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  2 +-
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 32 +++++++++++++++++++++-
 3 files changed, 63 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 133fa11..1706b3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -879,6 +879,37 @@ object TypeCoercion {
           }
         }
         e.withNewChildren(children)
+
+      case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
+        val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
+          implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+        }
+        udf.withNewChildren(children)
+    }
+
+    private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = {
+      (input, expectedType) match {
+        // SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note
+        // that precision and scale cannot be inferred properly for a ScalaUDF because, when it is
+        // created, it is not bound to any column. So here the precision and scale of the input
+        // column is used.
+        case (in: DecimalType, _: DecimalType) => in
+        case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
+          ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp)
+        case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) =>
+          MapType(udfInputToCastType(keyDtIn, keyDtExp),
+            udfInputToCastType(valueDtIn, valueDtExp),
+            nullableExp)
+        case (StructType(fieldsIn), StructType(fieldsExp)) =>
+          val fieldTypes =
+            fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) =>
+              udfInputToCastType(dtIn, dtExp)
+            }
+          StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) =>
+            field.copy(dataType = newDt)
+          })
+        case (_, other) => other
+      }
     }
 
     /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index fae90ca..a23aaa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -52,7 +52,7 @@ case class ScalaUDF(
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)
-  extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
+  extends Expression with NonSQLExpression with UserDefinedExpression {
 
   // The constructor for SPARK 2.1 and 2.2
   def this(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 20dcefa..a26d306 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.math.BigDecimal
+
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.QueryExecution
@@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm
 import org.apache.spark.sql.functions.{lit, udf}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.QueryExecutionListener
 
 
@@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
     }
   }
+
+  test("SPARK-26308: udf with decimal") {
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
+      StructType(Seq(StructField("col1", DecimalType(30, 0)))))
+    val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => {
+      if (value == null) null else value.toBigInteger.toString
+    })
+    checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556")))
+  }
+
+  test("SPARK-26308: udf with complex types of decimal") {
+    val df1 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))),
+      StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0))))))
+    val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => {
+      arr.map(value => if (value == null) null else value.toBigInteger.toString)
+    })
+    checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556"))))
+
+    val df2 = spark.createDataFrame(
+      sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))),
+      StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 
0))))))
+    val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => {
+      map.mapValues(value => if (value == null) null else value.toBigInteger.toString)
+    })
+    checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
