richardc-db commented on code in PR #47253:
URL: https://github.com/apache/spark/pull/47253#discussion_r1673073035
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala:
##########
@@ -63,6 +63,32 @@ trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression {
override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"
override def nullable: Boolean = true
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val check = super.checkInputDataTypes()
+ if (check.isFailure) {
+ check
+ } else {
+ val exprReturningVariant = children.collectFirst {
+ case e: Expression if typeContainsVariant(e.dataType) => e
+ }
+ exprReturningVariant match {
+ case Some(e) => TypeCheckResult.DataTypeMismatch(
+ errorSubClass = "UNSUPPORTED_UDF_INPUT_TYPE",
+ messageParameters = Map("dataType" -> s"\"${e.dataType.sql}\""))
+ case None => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+ }
+
+ def typeContainsVariant(dt: DataType): Boolean = dt match {
Review Comment:
I moved it to `VariantExpressionEvalUtils`
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]