Github user BryanCutler commented on the issue:
https://github.com/apache/spark/pull/21650
I gave it a shot at extracting the UDFs in one traversal, keying off the first
occurrence of either a Pandas or a batched UDF. I think it's much clearer:
```scala
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

  // Records the eval type of the first scalar Python UDF encountered, so only
  // UDFs with a matching eval type are extracted in the same traversal.
  private class FirstEvalType {
    var evalType = -1
    def isEvalTypeSet(): Boolean = evalType >= 0
  }

  private def canEvaluateInPython(e: PythonUDF, firstEvalType: FirstEvalType): Boolean = {
    if (firstEvalType.isEvalTypeSet() && e.evalType != firstEvalType.evalType) {
      false
    } else {
      firstEvalType.evalType = e.evalType
      e.children match {
        // A single PythonUDF child could be chained and evaluated in Python
        case Seq(u: PythonUDF) => canEvaluateInPython(u, firstEvalType)
        // Python UDF can't be evaluated directly in JVM
        case children => !children.exists(hasScalarPythonUDF)
      }
    }
  }

  private def collectEvaluableUDFs(expr: Expression, firstEvalType: FirstEvalType): Seq[PythonUDF] =
    expr match {
      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
          canEvaluateInPython(udf, firstEvalType) =>
        Seq(udf)
      case e => e.children.flatMap(collectEvaluableUDFs(_, firstEvalType))
    }

  private def extract(plan: SparkPlan): SparkPlan = {
    val udfs = plan.expressions.flatMap(collectEvaluableUDFs(_, new FirstEvalType))
    ...
```
This does pass around a mutable object; you could do about the same thing by
returning an Option instead, but that might not look as nice.
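For comparison, here is a rough sketch of what I mean by the Option-based
variant (hypothetical, not what I'm proposing): each call returns the
possibly-updated eval type along with the UDFs it collected. It omits the
chained-UDF handling from `canEvaluateInPython` for brevity and assumes the
same `Expression`/`PythonUDF` types as the snippet above.
```scala
// Sketch only: thread the first eval type through the recursion instead of
// mutating shared state.
private def collectEvaluableUDFs(
    expr: Expression,
    firstEvalType: Option[Int]): (Option[Int], Seq[PythonUDF]) = expr match {
  case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
      firstEvalType.forall(_ == udf.evalType) =>
    // First scalar UDF fixes the eval type; later ones must match it
    (Some(udf.evalType), Seq(udf))
  case e =>
    // Fold over the children, threading the eval type left to right
    e.children.foldLeft((firstEvalType, Seq.empty[PythonUDF])) {
      case ((evalType, udfs), child) =>
        val (nextEvalType, childUdfs) = collectEvaluableUDFs(child, evalType)
        (nextEvalType, udfs ++ childUdfs)
    }
}
```
The tuple-returning fold is what makes this version noisier than the mutable
helper, which is why I went with the latter.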