HyukjinKwon commented on a change in pull request #35615:
URL: https://github.com/apache/spark/pull/35615#discussion_r812477288



##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
##########
@@ -333,6 +367,56 @@ object IntegratedUDFTestUtils extends SQLHelper {
     val prettyName: String = "Scalar Pandas UDF"
   }
 
+  /**
+   * A Grouped Aggregate Pandas UDF that takes one column, executes the
+   * Python native function calculating the count of the column using pandas.
+   *
+   * Virtually equivalent to:
+   *
+   * {{{
+   *   import pandas as pd
+   *   from pyspark.sql.functions import pandas_udf
+   *
+   *   df = spark.createDataFrame(
+   *       [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
+   *
+   *   @pandas_udf("double")
+   *   def pandas_count(v: pd.Series) -> int:
+   *       return v.count()
+   *   count_col = pandas_count(df['v'])
+   * }}}
+   */
+  case class TestGroupedAggPandasUDF(name: String) extends TestUDF {
+    private[IntegratedUDFTestUtils] lazy val udf = new 
UserDefinedPythonFunction(
+      name = name,
+      func = PythonFunction(
+        command = pandasGroupedAggFunc,
+        envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, 
String]],
+        pythonIncludes = List.empty[String].asJava,
+        pythonExec = pythonExec,
+        pythonVer = pythonVer,
+        broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+        accumulator = null),
+      dataType = IntegerType,
+      pythonEvalType = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+      udfDeterministic = true) {
+
+      override def builder(e: Seq[Expression]): Expression = {
+        assert(e.length == 1, "Defined UDF only has one column")
+        val expr = e.head
+        assert(expr.resolved, "column should be resolved to use the same type 
" +
+          "as input. Try df(name) or df.col(name)")
+        val pythonUDF = new PythonUDFWithoutId(
+          super.builder(Cast(expr, IntegerType) :: 
Nil).asInstanceOf[PythonUDF])
+        Cast(pythonUDF, expr.dataType)

Review comment:
       I don't think we need a cast here, because the UDF always returns an
integer. I think we could just remove the `builder` override.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to