Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r199203674
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
---
@@ -94,36 +95,59 @@ object ExtractPythonUDFFromAggregate extends
Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
- private def hasPythonUDF(e: Expression): Boolean = {
+ private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}
- private def canEvaluateInPython(e: PythonUDF): Boolean = {
- e.children match {
- // single PythonUDF child could be chained and evaluated in Python
- case Seq(u: PythonUDF) => canEvaluateInPython(u)
- // Python UDF can't be evaluated directly in JVM
- case children => !children.exists(hasPythonUDF)
+ private def canEvaluateInPython(e: PythonUDF, evalType: Int): Boolean = {
+ if (e.evalType != evalType) {
+ false
+ } else {
+ e.children match {
+ // single PythonUDF child could be chained and evaluated in Python
+ case Seq(u: PythonUDF) => canEvaluateInPython(u, evalType)
+ // Python UDF can't be evaluated directly in JVM
+ case children => !children.exists(hasScalarPythonUDF)
+ }
}
}
- private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] =
expr match {
- case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
canEvaluateInPython(udf) => Seq(udf)
- case e => e.children.flatMap(collectEvaluatableUDF)
+ private def collectEvaluableUDF(expr: Expression, evalType: Int):
Seq[PythonUDF] = expr match {
+ case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
canEvaluateInPython(udf, evalType) =>
+ Seq(udf)
+ case e => e.children.flatMap(collectEvaluableUDF(_, evalType))
+ }
+
+ /**
+ * Collect evaluable UDFs from the current node.
+ *
+ * This function collects Python UDFs or Scalar Python UDFs from
expressions of the input node,
+ * and returns a list of UDFs of the same eval type.
--- End diff --
I tried this on master and got the same exception:
```
>>> foo = pandas_udf(lambda x: x, 'v int', PandasUDFType.GROUPED_MAP)
>>> df.select(foo(df['v'])).show()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File
"/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/dataframe.py",
line 353, in show
print(self._jdf.showString(n, 20, vertical))
File
"/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py",
line 1257, in __call__
File
"/Users/icexelloss/workspace/upstream/spark/python/pyspark/sql/utils.py", line
63, in deco
return f(*a, **kw)
File
"/Users/icexelloss/workspace/upstream/spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py",
line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling
o257.showString.
: java.lang.UnsupportedOperationException: Cannot evaluate expression:
<lambda>(input[0, bigint, false])
at
org.apache.spark.sql.catalyst.expressions.Unevaluable$class.doGenCode(Expression.scala:261)
at
org.apache.spark.sql.catalyst.expressions.PythonUDF.doGenCode(PythonUDF.scala:50)
at
org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:108)
at
org.apache.spark.sql.catalyst.expressions.Expression$$anonfun$genCode$2.apply(Expression.scala:105)
at scala.Option.getOrElse(Option.scala:121)
...
```
Therefore, this PR doesn't change that behavior. Neither master nor this PR
extracts non-scalar UDFs from expressions.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]