GitHub user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r205872386
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
---
@@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends
Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
- private def hasPythonUDF(e: Expression): Boolean = {
+ private type EvalType = Int
+ private type EvalTypeChecker = EvalType => Boolean
+
+ private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}
private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
- case Seq(u: PythonUDF) => canEvaluateInPython(u)
+ case Seq(u: PythonUDF) => e.evalType == u.evalType &&
canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
- case children => !children.exists(hasPythonUDF)
+ case children => !children.exists(hasScalarPythonUDF)
}
}
- private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] =
expr match {
- case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
canEvaluateInPython(udf) => Seq(udf)
- case e => e.children.flatMap(collectEvaluatableUDF)
+ private def collectEvaluableUDFsFromExpressions(expressions:
Seq[Expression]): Seq[PythonUDF] = {
+ // Eval type checker is set once when we find the first evaluable UDF
and its value
+ // shouldn't change later.
+ // Used to check if subsequent UDFs are of the same type as the first
UDF. (since we can only
+ // extract UDFs of the same eval type)
+ var evalTypeChecker: Option[EvalTypeChecker] = None
+
+ def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr
match {
+ case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) &&
canEvaluateInPython(udf)
+ && evalTypeChecker.isEmpty =>
+ evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType
== udf.evalType)
+ Seq(udf)
--- End diff --
@HyukjinKwon In your version, this line is `collectEvaluableUDFs(udf)`. I
think we should just return `Seq(udf)` here, so the same UDF expression is not checked a second time by the recursive call.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]