harshmotw-db commented on code in PR #48770:
URL: https://github.com/apache/spark/pull/48770#discussion_r1831622208
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala:
##########
@@ -169,6 +148,23 @@ case class PythonUDAF(
  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDAF =
    copy(children = newChildren)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val check = super.checkInputDataTypes()
+ if (check.isFailure) {
+ check
+ } else {
+ val exprReturningVariant = children.collectFirst {
+      case e: Expression if VariantExpressionEvalUtils.typeContainsVariant(e.dataType) => e
+    }
+ exprReturningVariant match {
+ case Some(e) => TypeCheckResult.DataTypeMismatch(
+ errorSubClass = "UNSUPPORTED_UDF_INPUT_TYPE",
Review Comment:
It was using the same error class before this PR: PythonUDAF inherited this check
from PythonFuncExpression, from which I have moved this code block. I did this so
that Variant is supported in the other UDF types but not in UDAFs.
Now, specifically regarding UDAFs: I found an [unofficial blog
post](https://danvatterott.com/blog/2018/09/06/python-aggregate-udfs-in-pyspark/)
on implementing aggregate UDFs in PySpark, so I'll have to look into it further.
If that works, I might just remove this code block.
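For context, the blog's workaround avoids a true UDAF by first collecting each group's values into a list and then applying an ordinary scalar function to that list. A minimal plain-Python sketch of that shape (no Spark required; all names here are illustrative, not from the PR):

```python
from collections import defaultdict

# Toy data: (group_key, value) rows, standing in for a Spark DataFrame.
rows = [("a", 1.0), ("a", 3.0), ("b", 2.0)]

# Step 1: the equivalent of groupBy(key).agg(collect_list(value)) --
# gather every value belonging to a group into one list.
grouped = defaultdict(list)
for key, value in rows:
    grouped[key].append(value)

# Step 2: apply an ordinary scalar function to each collected list.
# This plays the role of a regular (non-aggregate) UDF in the blog's trick,
# sidestepping the need for a real UDAF.
def mean_of_list(values):
    return sum(values) / len(values)

result = {key: mean_of_list(vals) for key, vals in grouped.items()}
# result == {"a": 2.0, "b": 2.0}
```

In Spark proper, step 2 would be a scalar Python UDF applied to the `collect_list` column, which is exactly why the input-type check here matters: the scalar path would still need to accept Variant-containing types.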
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
For additional commands, e-mail: [email protected]