Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r205819781
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala ---
@@ -94,36 +95,61 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
- private def hasPythonUDF(e: Expression): Boolean = {
+ private case class EvalTypeHolder(private var evalType: Int = -1) {
--- End diff ---
How about this:
```scala
private type EvalType = Int
private type EvalTypeChecker = Option[EvalType => Boolean]

private def collectEvaluableUDFsFromExpressions(
    expressions: Seq[Expression]): Seq[PythonUDF] = {
  // The eval type checker is set in the middle of collecting: once the first
  // evaluable UDF is found, every subsequent UDF must have the same eval type.
  var evalChecker: EvalTypeChecker = None

  def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
        && evalChecker.isEmpty =>
      evalChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
      collectEvaluableUDFs(expr)
    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
        && evalChecker.get(udf.evalType) =>
      Seq(udf)
    case e => e.children.flatMap(collectEvaluableUDFs)
  }

  expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: SparkPlan): SparkPlan = plan transformUp {
  case plan: SparkPlan => extract(plan)
}

/**
 * Extract all the PythonUDFs from the current operator and evaluate them
 * before the operator.
 */
private def extract(plan: SparkPlan): SparkPlan = {
  val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
```
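For what it's worth, here is a minimal, self-contained sketch of the same pattern that can be run on its own to see the behavior. `Node`, `Branch`, `Leaf`, and `tag` are hypothetical stand-ins for `Expression` and `PythonUDF.evalType` (not Spark APIs); the point is only that the first match pins the checker and every later candidate is filtered through it:

```scala
object FirstMatchCollector {
  // Hypothetical tree shape, standing in for Spark's Expression / PythonUDF.
  sealed trait Node { def children: Seq[Node] }
  case class Branch(children: Node*) extends Node
  case class Leaf(tag: Int) extends Node { def children: Seq[Node] = Nil }

  type Tag = Int
  type TagChecker = Option[Tag => Boolean]

  def collectFirstTagged(roots: Seq[Node]): Seq[Leaf] = {
    // The checker starts empty; the first leaf we see pins the tag that all
    // subsequent leaves must share, mirroring the eval type logic above.
    var checker: TagChecker = None

    def collect(node: Node): Seq[Leaf] = node match {
      case leaf: Leaf if checker.isEmpty =>
        checker = Some(other => other == leaf.tag)
        collect(leaf) // re-dispatch so the second case handles it uniformly
      case leaf: Leaf if checker.get(leaf.tag) =>
        Seq(leaf)
      case other => other.children.flatMap(collect)
    }

    roots.flatMap(collect)
  }

  def main(args: Array[String]): Unit = {
    val tree = Branch(Leaf(1), Branch(Leaf(2), Leaf(1)))
    // Only leaves sharing the first tag seen (1) are collected.
    println(collectFirstTagged(Seq(tree))) // List(Leaf(1), Leaf(1))
  }
}
```

Because `checker` is a local `var` closed over by a nested function, the state is scoped to a single traversal, which is what makes this preferable to passing a mutable `EvalTypeHolder` around.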