cloud-fan commented on a change in pull request #28645:
URL: https://github.com/apache/spark/pull/28645#discussion_r442631441



##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
##########
@@ -581,4 +581,99 @@ class UDFSuite extends QueryTest with SharedSparkSession {
       .toDF("col1", "col2")
     checkAnswer(df.select(myUdf(Column("col1"), Column("col2"))), Row(2020) :: 
Nil)
   }
+
+  test("case class as element type of Seq/Array") {
+    val f1 = (s: Seq[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Seq(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Array[TestData]) => s.map(d => d.key * d.value.toInt).sum
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Array(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as key/value type of Map") {
+    val f1 = (s: Map[TestData, Int]) => s.keys.head.key * 
s.keys.head.value.toInt
+    val myUdf1 = udf(f1)
+    val df1 = Seq(("data", Map(TestData(50, "2") -> 502))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf1(Column("col2"))), Row(100) :: Nil)
+
+    val f2 = (s: Map[Int, TestData]) => s.values.head.key * 
s.values.head.value.toInt
+    val myUdf2 = udf(f2)
+    val df2 = Seq(("data", Map(502 -> TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf2(Column("col2"))), Row(100) :: Nil)
+
+    val f3 = (s: Map[TestData, TestData]) => s.keys.head.key * 
s.values.head.value.toInt
+    val myUdf3 = udf(f3)
+    val df3 = Seq(("data", Map(TestData(50, "2") -> TestData(50, 
"2")))).toDF("col1", "col2")
+    checkAnswer(df3.select(myUdf3(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as element of tuple") {
+    val f = (s: (TestData, Int)) => s._1.key * s._2
+    val myUdf = udf(f)
+    val df = Seq(("data", (TestData(50, "2"), 2))).toDF("col1", "col2")
+    checkAnswer(df.select(myUdf(Column("col2"))), Row(100) :: Nil)
+  }
+
+  test("case class as generic type of Option") {
+    val f = (o: Option[TestData]) => o.map(t => t.key * t.value.toInt)
+    val myUdf = udf(f)
+    val df1 = Seq(("data", Some(TestData(50, "2")))).toDF("col1", "col2")
+    checkAnswer(df1.select(myUdf(Column("col2"))), Row(100) :: Nil)
+    val df2 = Seq(("data", None: Option[TestData])).toDF("col1", "col2")
+    checkAnswer(df2.select(myUdf(Column("col2"))), Row(null) :: Nil)
+  }
+
+  test("more input fields than expect for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = spark.range(1)
+      .select(lit(50).as("a"), lit(2).as("b"), lit(2).as("c"))
+      .select(struct("a", "b", "c").as("col"))
+    checkAnswer(df.select(myUdf(Column("col"))), Row(100) :: Nil)
+  }
+
+  test("less input fields than expect for case class") {
+    val f = (t: TestData2) => t.a * t.b
+    val myUdf = udf(f)
+    val df = spark.range(1)
+      .select(lit(50).as("a"))
+      .select(struct("a").as("col"))
+    val error = intercept[AnalysisException] (df.select(myUdf(Column("col"))))

Review comment:
       nit: remove the space between `intercept[AnalysisException]` and `(df.select...`




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to