harshmotw-db commented on code in PR #48770:
URL: https://github.com/apache/spark/pull/48770#discussion_r1831622208


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala:
##########
@@ -169,6 +148,23 @@ case class PythonUDAF(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDAF =
     copy(children = newChildren)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else {
+      val exprReturningVariant = children.collectFirst {
+        case e: Expression if VariantExpressionEvalUtils.typeContainsVariant(e.dataType) => e
+      }
+      exprReturningVariant match {
+        case Some(e) => TypeCheckResult.DataTypeMismatch(
+          errorSubClass = "UNSUPPORTED_UDF_INPUT_TYPE",

Review Comment:
   PythonUDAF was already using this error class before this PR: it inherited
   the check from PythonFuncExpression, which is where I moved this code block
   from. I moved it so that Variant is supported in the other UDF types but
   not in UDAFs.
   
   As for UDAFs specifically, I found an
   [unofficial blog post](https://danvatterott.com/blog/2018/09/06/python-aggregate-udfs-in-pyspark/)
   on implementing Python aggregate UDFs in PySpark, so I'll have to look into
   it further. If Variant works with them, I might just remove this code block.
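
   For readers without the full file at hand, here is a minimal, self-contained
   sketch of the shape of the check in the hunk above: take the first child
   whose type contains a variant and report it, otherwise pass. The `Toy*`
   types, `containsVariant`, and `checkInputs` are hypothetical stand-ins for
   illustration only. They are not Spark's `DataType`, `Expression`,
   `VariantExpressionEvalUtils.typeContainsVariant`, or `TypeCheckResult`, and
   the assumption that the real helper recurses into nested types is based
   only on its name.

   ```scala
   // Toy stand-ins for a data type hierarchy (hypothetical, not Spark classes).
   sealed trait ToyType
   case object ToyInt extends ToyType
   case object ToyVariant extends ToyType
   final case class ToyArray(element: ToyType) extends ToyType
   final case class ToyStruct(fields: Seq[ToyType]) extends ToyType

   // Hypothetical input expression: a name plus the type it produces.
   final case class ToyExpr(name: String, dataType: ToyType)

   object VariantCheckSketch {
     // True if the type is a variant or nests one at any depth.
     def containsVariant(t: ToyType): Boolean = t match {
       case ToyVariant    => true
       case ToyArray(e)   => containsVariant(e)
       case ToyStruct(fs) => fs.exists(containsVariant)
       case ToyInt        => false
     }

     // Same collectFirst pattern as the diff: Left names the first offending
     // child, Right(()) means every input passed.
     def checkInputs(children: Seq[ToyExpr]): Either[String, Unit] =
       children.collectFirst { case e if containsVariant(e.dataType) => e } match {
         case Some(e) => Left(s"UNSUPPORTED_UDF_INPUT_TYPE: ${e.name}")
         case None    => Right(())
       }
   }
   ```

   Keeping the check on the UDAF class alone, as this PR does, means the other
   UDF expression types never run it and so accept variant inputs.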



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

