Repository: spark
Updated Branches:
  refs/heads/master 3a3e65ada -> 422a45cf0


[SPARK-18766][SQL] Push Down Filter Through BatchEvalPython (Python UDF)

### What changes were proposed in this pull request?
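Currently, when users use a Python UDF in a Filter, BatchEvalPython is always
generated below FilterExec. However, not all the predicates need to be
evaluated after the Python UDF executes. This PR therefore pushes the
deterministic predicates that do not contain a Python UDF down through
`BatchEvalPython`. Only the conjuncts that appear before the first
non-deterministic predicate are candidates for push-down.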
```Python
>>> df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
>>> from pyspark.sql.functions import udf, col
>>> from pyspark.sql.types import BooleanType
>>> my_filter = udf(lambda a: a < 2, BooleanType())
>>> sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
>>> sel.explain(True)
```
Before the fix, the plan looks like:
```
== Optimized Logical Plan ==
Filter ((isnotnull(value#1) && <lambda>(key#0L)) && (value#1 < 2))
+- LogicalRDD [key#0L, value#1]

== Physical Plan ==
*Project [key#0L, value#1]
+- *Filter ((isnotnull(value#1) && pythonUDF0#9) && (value#1 < 2))
   +- BatchEvalPython [<lambda>(key#0L)], [key#0L, value#1, pythonUDF0#9]
      +- Scan ExistingRDD[key#0L,value#1]
```

After the fix, the plan looks like:
```
== Optimized Logical Plan ==
Filter ((isnotnull(value#1) && <lambda>(key#0L)) && (value#1 < 2))
+- LogicalRDD [key#0L, value#1]

== Physical Plan ==
*Project [key#0L, value#1]
+- *Filter pythonUDF0#9: boolean
   +- BatchEvalPython [<lambda>(key#0L)], [key#0L, value#1, pythonUDF0#9]
      +- *Filter (isnotnull(value#1) && (value#1 < 2))
         +- Scan ExistingRDD[key#0L,value#1]
```

### How was this patch tested?
Added unit test cases for `BatchEvalPythonExec` and an end-to-end test case
to the Python test suite.

Author: gatorsmile <[email protected]>

Closes #16193 from gatorsmile/pythonUDFPredicatePushDown.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/422a45cf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/422a45cf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/422a45cf

Branch: refs/heads/master
Commit: 422a45cf0490cc354fa9348a2381a337d52c4f58
Parents: 3a3e65a
Author: gatorsmile <[email protected]>
Authored: Sat Dec 10 08:47:45 2016 -0800
Committer: gatorsmile <[email protected]>
Committed: Sat Dec 10 08:47:45 2016 -0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     |   9 ++
 .../execution/python/ExtractPythonUDFs.scala    |  29 ++++-
 .../python/BatchEvalPythonExecSuite.scala       | 110 +++++++++++++++++++
 3 files changed, 143 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/422a45cf/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 66320bd..af7d52c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -360,6 +360,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.spark.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_filter_function(self):
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a < 2, BooleanType())
+        sel = df.select(col("key"), 
col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
+        self.assertEqual(sel.collect(), [Row(key=1, value='1')])
+
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, 
"2")], ["key", "value"])
         from pyspark.sql.functions import udf, col, sum

http://git-wip-us.apache.org/repos/asf/spark/blob/422a45cf/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 16e4484..69b4b7b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -25,7 +25,7 @@ import 
org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, 
Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
 
 
 /**
@@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends 
Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed 
include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[SparkPlan] {
+object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -126,10 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
       plan
     } else {
       val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      val splitFilter = trySplitFilter(plan)
       // Rewrite the child that has the input required for the UDF
-      val newChildren = plan.children.map { child =>
+      val newChildren = splitFilter.children.map { child =>
         // Pick the UDF we are going to evaluate
-        val validUdfs = udfs.filter { case udf =>
+        val validUdfs = udfs.filter { udf =>
           // Check to make sure that the UDF can be evaluated with only the 
input of this child.
           udf.references.subsetOf(child.outputSet)
         }.toArray  // Turn it into an array since iterators cannot be 
serialized in Scala 2.10
@@ -150,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
         sys.error(s"Invalid PythonUDF $udf, requires attributes from more than 
one child.")
       }
 
-      val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+      val rewritten = 
splitFilter.withNewChildren(newChildren).transformExpressions {
         case p: PythonUDF if attributeMap.contains(p) =>
           attributeMap(p)
       }
@@ -165,4 +166,22 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
       }
     }
   }
+
+  // Split the original FilterExec to two FilterExecs. Only push down the 
first few predicates
+  // that are all deterministic.
+  private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case filter: FilterExec =>
+        val (candidates, containingNonDeterministic) =
+          splitConjunctivePredicates(filter.condition).span(_.deterministic)
+        val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+        if (pushDown.nonEmpty) {
+          val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
+          FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), 
newChild)
+        } else {
+          filter
+        }
+      case o => o
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/422a45cf/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
new file mode 100644
index 0000000..81bea2f
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.api.python.PythonFunction
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, 
GreaterThan, In}
+import org.apache.spark.sql.execution.{FilterExec, InputAdapter, 
SparkPlanTest, WholeStageCodegenExec}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.BooleanType
+
+class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
+  }
+
+  override def afterAll(): Unit = {
+    spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF")
+    super.afterAll()
+  }
+
+  test("Python UDF: push down deterministic FilterExec predicates") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(
+          And(_: AttributeReference, _: AttributeReference),
+          InputAdapter(_: BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: 
In, _))) => b
+    }
+    assert(qualifiedPlanNodes.size == 2)
+  }
+
+  test("Nested Python UDF: push down deterministic FilterExec predicates") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(_: AttributeReference, InputAdapter(_: 
BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: 
In, _))) => b
+    }
+    assert(qualifiedPlanNodes.size == 2)
+  }
+
+  test("Python UDF: no push down on non-deterministic") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("b > 4 and dummyPythonUDF(a) and rand() > 3")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(
+          And(_: AttributeReference, _: GreaterThan),
+          InputAdapter(_: BatchEvalPythonExec)) => f
+      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) 
=> b
+    }
+    assert(qualifiedPlanNodes.size == 2)
+  }
+
+  test("Python UDF: no push down on predicates starting from the first 
non-deterministic") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+      .where("dummyPythonUDF(a) and rand() > 3 and b > 4")
+    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+      case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: 
BatchEvalPythonExec)) => f
+    }
+    assert(qualifiedPlanNodes.size == 1)
+  }
+
+  test("Python UDF refers to the attributes from more than one child") {
+    val df = Seq(("Hello", 4)).toDF("a", "b")
+    val df2 = Seq(("Hello", 4)).toDF("c", "d")
+    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, 
c)")
+
+    val e = intercept[RuntimeException] {
+      joinDF.queryExecution.executedPlan
+    }.getMessage
+    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more 
than one child")
+      .forall(e.contains))
+  }
+}
+
+// This Python UDF is dummy and just for testing. Unable to execute.
+class DummyUDF extends PythonFunction(
+  command = Array[Byte](),
+  envVars = Map("" -> "").asJava,
+  pythonIncludes = ArrayBuffer("").asJava,
+  pythonExec = "",
+  pythonVer = "",
+  broadcastVars = null,
+  accumulator = null)
+
+class MyDummyPythonUDF
+  extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, 
dataType = BooleanType)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to