This is an automated email from the ASF dual-hosted git repository.
cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7858e53 [SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition
7858e53 is described below
commit 7858e534d3195d532874a3d90121353895ba3f42
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Jul 10 16:29:58 2019 -0700
[SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition
## What changes were proposed in this pull request?
There is a bug in `ExtractPythonUDFs` that produces wrong result attributes.
It causes a failure when `PythonUDF`s are used across multiple child plans,
e.g., in a join. An example is using `PythonUDF`s in a join condition.
```python
>>> left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
>>> right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
>>> f = udf(lambda a: a, IntegerType())
>>> df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
>>> df.collect()
19/07/10 12:20:49 ERROR Executor: Exception in task 5.0 in stage 0.0 (TID 5)
java.lang.ArrayIndexOutOfBoundsException: 1
    at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201)
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35)
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt(rows.scala:36)
    at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt$(rows.scala:36)
    at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.isNullAt(rows.scala:195)
    at org.apache.spark.sql.catalyst.expressions.JoinedRow.isNullAt(JoinedRow.scala:70)
...
```
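For reference, here is a self-contained sketch of the same reproduction, together with the result the query is expected to return once the fix is in place (matching the assertion added in `test_udf_as_join_condition` below). The local `SparkSession` setup is illustrative and not part of the original report.
```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Illustrative local session; any SparkSession works here.
spark = SparkSession.builder.master("local[*]").appName("SPARK-28323").getOrCreate()

left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])

# Identity UDF applied to one column from each side of the join condition.
f = udf(lambda a: a, IntegerType())

df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])

# With the fix, this prints [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)]
# instead of failing with java.lang.ArrayIndexOutOfBoundsException.
print(df.collect())
```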
## How was this patch tested?
Added tests.
Closes #25091 from viirya/SPARK-28323.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
---
python/pyspark/sql/tests/test_udf.py | 10 +++++++++
.../sql/execution/python/ExtractPythonUDFs.scala | 2 +-
.../scala/org/apache/spark/sql/JoinSuite.scala | 25 ++++++++++++++++++++++
3 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 0dafa18..803d471 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -197,6 +197,8 @@ class UDFTests(ReusedSQLTestCase):
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
+ # The udf uses attributes from both sides of join, so it is pulled out as
+ # Filter + Cross join.
df = left.join(right, f("a", "b"))
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
df.collect()
@@ -243,6 +245,14 @@ class UDFTests(ReusedSQLTestCase):
runWithJoinType("leftanti", "LeftAnti")
runWithJoinType("leftsemi", "LeftSemi")
+ def test_udf_as_join_condition(self):
+ left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+ right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+ f = udf(lambda a: a, IntegerType())
+
+ df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
+ self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
def test_udf_without_arguments(self):
self.spark.catalog.registerFunction("foo", lambda: "bar")
[row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 58fe7d5..fc4ded3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -179,7 +179,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
validUdfs.forall(PythonUDF.isScalarPythonUDF),
"Can only extract scalar vectorized udf or sql batch udf")
- val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+ val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
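For context on why this one-line change fixes the failure: `extract` rewrites each child of the plan separately, and `validUdfs` is the subset of `udfs` whose references are satisfied by that child (see the `validUdfs.forall` check just above). Building `resultAttrs` from all `udfs` made the per-child Python evaluation node declare more result attributes than the UDFs it actually evaluates, which is what surfaced as the `ArrayIndexOutOfBoundsException` above. A quick way to observe the corrected behaviour from Python, reusing `left`, `right`, and `f` from the sketch above, is to inspect the physical plan; it should show one BatchEvalPython node per join side, the same property the JoinSuite test below asserts via `BatchEvalPythonExec`.
```python
# Assumes `left`, `right`, and `f` from the reproduction sketch above.
# Each PythonUDF in this condition references a single join side, so it is
# extracted and evaluated on that side before the join; the join itself
# becomes an equi-join on the two UDF outputs, so no implicit cartesian
# product check is triggered.
df2 = left.join(right, f("a") == f("b"))

# The physical plan is expected to contain two BatchEvalPython nodes,
# one feeding each side of the join.
df2.explain()
```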
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 38c634e..32cddc9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.python.BatchEvalPythonExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
@@ -969,4 +970,28 @@ class JoinSuite extends QueryTest with SharedSQLContext {
Seq(Row(0.0d, 0.0/0.0)))))
}
}
+
+ test("SPARK-28323: PythonUDF should be able to use in join condition") {
+ import IntegratedUDFTestUtils._
+
+ assume(shouldTestPythonUDFs)
+
+ val pythonTestUDF = TestPythonUDF(name = "udf")
+
+ val left = Seq((1, 2), (2, 3)).toDF("a", "b")
+ val right = Seq((1, 2), (3, 4)).toDF("c", "d")
+ val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c"))
+
+ val joinNode = df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec])
+ assert(joinNode.isDefined)
+
+ // There are two PythonUDFs, each using an attribute from one side of the join.
+ // So the two PythonUDFs should be evaluated before the join operator, one on each side.
+ val pythonEvals = joinNode.get.collect {
+ case p: BatchEvalPythonExec => p
+ }
+ assert(pythonEvals.size == 2)
+
+ checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
+ }
}