WangGuangxin commented on code in PR #8669:
URL: https://github.com/apache/incubator-gluten/pull/8669#discussion_r1947537261
##########
gluten-ut/spark35/src/test/scala/org/apache/spark/sql/hive/execution/GlutenHiveUDFSuite.scala:
##########
@@ -97,43 +51,99 @@ class GlutenHiveUDFSuite
tableDF.createOrReplaceTempView(table)
}
- override protected def afterAll(): Unit = {
- try {
- hiveContext.reset()
- } finally {
- super.afterAll()
- }
- }
-
- override protected def shouldRun(testName: String): Boolean = {
- false
+ override def afterAll(): Unit = {
+ super.afterAll()
}
test("customer udf") {
- sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
- val df = spark.sql("""select testUDF(l_comment)
- | from lineitem""".stripMargin)
- df.show()
- print(df.queryExecution.executedPlan)
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
+ withTempFunction("testUDF") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+ val df = sql("select l_partkey, testUDF(l_comment) from lineitem")
+ df.show()
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
}
test("customer udf wrapped in function") {
- sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
- val df = spark.sql("""select hash(testUDF(l_comment))
- | from lineitem""".stripMargin)
- df.show()
- print(df.queryExecution.executedPlan)
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
+ withTempFunction("testUDF") {
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[CustomerUDF].getName}'")
+ val df = sql("select l_partkey, hash(testUDF(l_comment)) from lineitem")
+ df.show()
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
}
test("example") {
- spark.sql("CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
- spark.sql("select testUDF('l_commen', 1, 5)").show()
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
- hiveContext.reset()
+ withTempFunction("testUDF") {
+ sql("CREATE TEMPORARY FUNCTION testUDF AS 'org.apache.hadoop.hive.ql.udf.UDFSubstr';")
+ val df = sql("select testUDF('l_commen', 1, 5)")
+ df.show()
+ // It should not be converted to ColumnarPartialProjectExec, since
+ // the UDF need all the columns in child output.
+ assert(!getExecutedPlan(df).exists {
+ case _: ColumnarPartialProjectExec => true
+ case _ => false
+ })
+ }
+ }
+
+ test("udf with array") {
+ withTempFunction("udf_sort_array") {
+ sql("""
+ |CREATE TEMPORARY FUNCTION udf_sort_array AS
+ |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFSortArray';
+ |""".stripMargin)
+
+ val df = sql("""
+ |SELECT
+ | l_orderkey,
+ | l_partkey,
+ | udf_sort_array(array(10, l_orderkey, 1)) as udf_result
+ |FROM lineitem WHERE l_partkey <= 5 and l_orderkey <1000
+ |""".stripMargin)
+
+ checkAnswer(
+ df,
+ Seq(
+ Row(35, 5, mutable.WrappedArray.make(Array(1, 10, 35))),
+ Row(321, 4, mutable.WrappedArray.make(Array(1, 10, 321))),
+ Row(548, 2, mutable.WrappedArray.make(Array(1, 10, 548))),
+ Row(640, 5, mutable.WrappedArray.make(Array(1, 10, 640))),
+ Row(807, 2, mutable.WrappedArray.make(Array(1, 10, 807)))
+ )
+ )
+ checkOperatorMatch[ColumnarPartialProjectExec](df)
+ }
}
+ test("udf with map") {
Review Comment:
@Yohahaha done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]