Github user icexelloss commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r205262719
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
---
@@ -94,36 +95,94 @@ object ExtractPythonUDFFromAggregate extends
Rule[LogicalPlan] {
*/
object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
- private def hasPythonUDF(e: Expression): Boolean = {
+ private case class LazyEvalType(var evalType: Int = -1) {
+
+ def isSet: Boolean = evalType >= 0
+
+ def set(evalType: Int): Unit = {
+ if (isSet) {
+ throw new IllegalStateException("Eval type has already been set")
+ } else {
+ this.evalType = evalType
+ }
+ }
+
+ def get(): Int = {
+ if (!isSet) {
+ throw new IllegalStateException("Eval type is not set")
+ } else {
+ evalType
+ }
+ }
+ }
+
+ private def hasScalarPythonUDF(e: Expression): Boolean = {
e.find(PythonUDF.isScalarPythonUDF).isDefined
}
- private def canEvaluateInPython(e: PythonUDF): Boolean = {
- e.children match {
- // single PythonUDF child could be chained and evaluated in Python
- case Seq(u: PythonUDF) => canEvaluateInPython(u)
- // Python UDF can't be evaluated directly in JVM
- case children => !children.exists(hasPythonUDF)
+ /**
+ * Check whether a PythonUDF expression can be evaluated in Python.
+ *
+ * If the lazy eval type is not set, this method checks for either
Batched Python UDF and Scalar
+ * Pandas UDF. If the lazy eval type is set, this method checks for the
expression of the
+ * specified eval type.
+ *
+ * This method will also set the lazy eval type to be the type of the
first evaluable expression,
+ * i.e., if lazy eval type is not set and we find a evaluable Python UDF
expression, lazy eval
+ * type will be set to the eval type of the expression.
+ *
+ */
+ private def canEvaluateInPython(e: PythonUDF, lazyEvalType:
LazyEvalType): Boolean = {
--- End diff --
Bryan, I tried applying your implementation, and the simple test also fails:
```
@udf('int')
def f1(x):
assert type(x) == int
return x + 1
@pandas_udf('int')
def f2(x):
assert type(x) == pd.Series
return x + 10
df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
```
Do you mind trying this too? Hopefully I didn't do something silly here.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]