Github user BryanCutler commented on a diff in the pull request:
https://github.com/apache/spark/pull/21650#discussion_r202864194
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
---
@@ -97,6 +103,64 @@ class BatchEvalPythonExecSuite extends SparkPlanTest
with SharedSQLContext {
}
assert(qualifiedPlanNodes.size == 1)
}
+
+ private def collectPythonExec(plan: SparkPlan): Seq[BatchEvalPythonExec]
= plan.collect {
+ case b: BatchEvalPythonExec => b
+ }
+
+ private def collectPandasExec(plan: SparkPlan): Seq[ArrowEvalPythonExec]
= plan.collect {
+ case b: ArrowEvalPythonExec => b
+ }
+
+ test("Chained Python UDFs should be combined to a single physical node")
{
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", pythonUDF(col("a"))).withColumn("d",
pythonUDF(col("c")))
+ val pythonEvalNodes =
collectPythonExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ }
+
+ test("Chained Pandas UDFs should be combined to a single physical node")
{
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", pandasUDF(col("a"))).withColumn("d",
pandasUDF(col("c")))
+ val arrowEvalNodes = collectPandasExec(df2.queryExecution.executedPlan)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Mixed Python UDFs and Pandas UDF should be separate physical
node") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c", pythonUDF(col("a"))).withColumn("d",
pandasUDF(col("b")))
+
+ val pythonEvalNodes =
collectPythonExec(df2.queryExecution.executedPlan)
+ val arrowEvalNodes = collectPandasExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Independent Python UDFs and Pandas UDFs should be combined
separately") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = df.withColumn("c1", pythonUDF(col("a")))
+ .withColumn("c2", pythonUDF(col("c1")))
+ .withColumn("d1", pandasUDF(col("a")))
+ .withColumn("d2", pandasUDF(col("d1")))
+
+ val pythonEvalNodes =
collectPythonExec(df2.queryExecution.executedPlan)
+ val arrowEvalNodes = collectPandasExec(df2.queryExecution.executedPlan)
+ assert(pythonEvalNodes.size == 1)
+ assert(arrowEvalNodes.size == 1)
+ }
+
+ test("Dependent Python UDFs and Pandas UDFs should not be combined") {
--- End diff --
Suggested test name: "Dependent Python Batched..."
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org