cloud-fan commented on a change in pull request #28645:
URL: https://github.com/apache/spark/pull/28645#discussion_r437349734



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2847,6 +2848,45 @@ class Analyzer(
     }
   }
 
+  /**
+   * Resolve the encoders for the UDF by explicitly given the attributes. We 
give the
+   * attributes explicitly in order to handle the case where the data type of 
the input
+   * value is not the same with the internal schema of the encoder, which 
could cause
+   * data loss. For example, the encoder should not cast the input value to 
Decimal(38, 18)
+   * if the actual data type is Decimal(30, 0).
+   *
+   * The resolved encoders then will be used to deserialize the internal row 
to Scala value.
+   */
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case p => p transformExpressionsUp {
+
+        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _) if 
encoders.nonEmpty =>
+          val resolvedEncoders = encoders.zipWithIndex.map { case (encOpt, i) 
=>
+            val dataType = inputs(i).dataType
+            if (dataType.isInstanceOf[UserDefinedType[_]]) {

Review comment:
       what about struct/array/map of UDT?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -2847,6 +2848,45 @@ class Analyzer(
     }
   }
 
+  /**
+   * Resolve the encoders for the UDF by explicitly given the attributes. We 
give the
+   * attributes explicitly in order to handle the case where the data type of 
the input
+   * value is not the same with the internal schema of the encoder, which 
could cause
+   * data loss. For example, the encoder should not cast the input value to 
Decimal(38, 18)
+   * if the actual data type is Decimal(30, 0).
+   *
+   * The resolved encoders then will be used to deserialize the internal row 
to Scala value.
+   */
+  object ResolveEncodersInUDF extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUp {
+      case p if !p.resolved => p // Skip unresolved nodes.
+
+      case p => p transformExpressionsUp {
+
+        case udf @ ScalaUDF(_, _, inputs, encoders, _, _, _) if 
encoders.nonEmpty =>
+          val resolvedEncoders = encoders.zipWithIndex.map { case (encOpt, i) 
=>

Review comment:
       Maybe call it `boundEncoders`.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
##########
@@ -581,4 +581,69 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .toDF("col1", "col2")
     checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: 
Nil)
   }
+
+  test("case class as element type of Seq/Array") {
+    val f1 = (s: Seq[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Seq(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Array[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Array(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as key/value type of Map") {
+    val f1 = (s: Map[TestData, Int]) => s.keys.head.key * 
s.keys.head.value.toInt
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Map(TestData(50, "2") -> 502))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Map[Int, TestData]) => s.values.head.key * 
s.values.head.value.toInt
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Map(502 -> TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+
+    val f3 = (s: Map[TestData, TestData]) => s.keys.head.key * 
s.values.head.value.toInt
+    val myUdf3 = udf(f3)
+    val df3 = Seq(("data", Map(TestData(50, "2") -> TestData(50, 
"2")))).toDF("col1", "col2")
+    checkAnswer(df3.select(myUdf3(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as element of tuple") {
+    val f = (s: (TestData, Int)) => s._1.key * s._2
+    val myUdf = udf(f)
+    val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as generic type of Option") {
+    val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt)
+    val myUdf = udf(f)
+    val df = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("more input fields than expect for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = Seq(("data", TestData4(50, 2, 2))).toDF("col1", "col2")

Review comment:
       Let's avoid creating too many TestData variants.
   
   Here we can create the DataFrame directly instead: 
`spark.range(1).select(lit(50).as("a"), ...)`

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
##########
@@ -581,4 +581,69 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .toDF("col1", "col2")
     checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: 
Nil)
   }
+
+  test("case class as element type of Seq/Array") {
+    val f1 = (s: Seq[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Seq(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Array[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Array(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as key/value type of Map") {
+    val f1 = (s: Map[TestData, Int]) => s.keys.head.key * 
s.keys.head.value.toInt
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Map(TestData(50, "2") -> 502))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Map[Int, TestData]) => s.values.head.key * 
s.values.head.value.toInt
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Map(502 -> TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+
+    val f3 = (s: Map[TestData, TestData]) => s.keys.head.key * 
s.values.head.value.toInt
+    val myUdf3 = udf(f3)
+    val df3 = Seq(("data", Map(TestData(50, "2") -> TestData(50, 
"2")))).toDF("col1", "col2")
+    checkAnswer(df3.select(myUdf3(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as element of tuple") {
+    val f = (s: (TestData, Int)) => s._1.key * s._2
+    val myUdf = udf(f)
+    val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as generic type of Option") {
+    val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt)
+    val myUdf = udf(f)
+    val df = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("more input fields than expect for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = Seq(("data", TestData4(50, 2, 2))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("less input fields than expect for case class") {
+    val f = (t: TestData4) => t.a * t.b * t.c
+    val myUdf = udf(f)
+    val df = Seq(("data", TestData2(50, 2))).toDF("col1", "col2")
+    val error = intercept[AnalysisException] (df.select(myUdf(Column("col2"))))
+    assert(error.getMessage.contains("cannot resolve '`c`' given input 
columns: [a, b]"))
+  }
+
+  test("wrong order of input fields for case class") {
+    val f = (t: TestData) => t.key * t.value.toInt
+    val myUdf = udf(f)
+    val df = Seq(("data", TestData5("2", 50))).toDF("col1", "col2")

Review comment:
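       Ditto: avoid adding another TestData variant here; build the input DataFrame directly (e.g. with `spark.range(1).select(...)`).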




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to