Github user xuanyuanking commented on a diff in the pull request:
https://github.com/apache/spark/pull/22326#discussion_r216127932
--- Diff: python/pyspark/sql/tests.py ---
@@ -547,6 +547,74 @@ def test_udf_in_filter_on_top_of_join(self):
         df = left.crossJoin(right).filter(f("a", "b"))
         self.assertEqual(df.collect(), [Row(a=1, b=1)])
 
+    def test_udf_in_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_filter_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a Python udf and ordinary
+        # (non-udf) filter predicates in the same join condition
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=2, b1=1, b2=2)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b1"), left.a == 1, right.b == 2])
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=2, b1=1, b2=2)])
+
+    def test_udf_and_filter_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a Python udf and ordinary
+        # (non-udf) filter predicates in the same join condition
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=2, b1=1, b2=2)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b1"), left.a == 1, right.b == 2], "left_semi")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
--- End diff --
Added these two tests to address the comment at
https://github.com/apache/spark/pull/22326#discussion_r216127673.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]