Repository: spark Updated Branches: refs/heads/master 8d33e1e5b -> 8880fd13e
[SPARK-14761][SQL] Reject invalid join methods when join columns are not specified in PySpark DataFrame join. ## What changes were proposed in this pull request? In PySpark, an invalid join type does not raise an error for the following join: ```df1.join(df2, how='not-a-valid-join-type')``` The signature of the join is: ```def join(self, other, on=None, how=None):``` The existing code completely ignores the `how` parameter when `on` is `None`. This patch processes the arguments passed to `join` and passes them to the JVM Spark SQL Analyzer, which validates the join type. ## How was this patch tested? Tested manually and with the existing test suites. Author: Bijay Pathak <bkpat...@mtu.edu> Closes #15409 from bkpathak/SPARK-14761. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8880fd13 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8880fd13 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8880fd13 Branch: refs/heads/master Commit: 8880fd13ef2b581f9c7190e7e3e6d24bc11b4ef7 Parents: 8d33e1e Author: Bijay Pathak <bkpat...@mtu.edu> Authored: Wed Oct 12 10:09:49 2016 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Wed Oct 12 10:09:49 2016 -0700 ---------------------------------------------------------------------- python/pyspark/sql/dataframe.py | 31 +++++++++++++++---------------- python/pyspark/sql/tests.py | 6 ++++++ 2 files changed, 21 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8880fd13/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 14e80ea..ce277eb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -661,25 +661,24 @@ class DataFrame(object): if on is not None and not isinstance(on, list): on = [on] 
- if on is None or len(on) == 0: - jdf = self._jdf.crossJoin(other._jdf) - elif isinstance(on[0], basestring): - if how is None: - jdf = self._jdf.join(other._jdf, self._jseq(on), "inner") + if on is not None: + if isinstance(on[0], basestring): + on = self._jseq(on) else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, self._jseq(on), how) + assert isinstance(on[0], Column), "on should be Column or list of Column" + if len(on) > 1: + on = reduce(lambda x, y: x.__and__(y), on) + else: + on = on[0] + on = on._jc + + if on is None and how is None: + jdf = self._jdf.crossJoin(other._jdf) else: - assert isinstance(on[0], Column), "on should be Column or list of Column" - if len(on) > 1: - on = reduce(lambda x, y: x.__and__(y), on) - else: - on = on[0] if how is None: - jdf = self._jdf.join(other._jdf, on._jc, "inner") - else: - assert isinstance(how, basestring), "how should be basestring" - jdf = self._jdf.join(other._jdf, on._jc, how) + how = "inner" + assert isinstance(how, basestring), "how should be basestring" + jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) @since(1.6) http://git-wip-us.apache.org/repos/asf/spark/blob/8880fd13/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 86c590d..61674a8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1508,6 +1508,12 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(df.schema.simpleString(), "struct<value:int>") self.assertEqual(df.collect(), [Row(key=i) for i in range(100)]) + # Regression test for invalid join methods when on is None, Spark-14761 + def test_invalid_join_method(self): + df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"]) + df2 = self.spark.createDataFrame([("Alice", 80), ("Bob", 90)], ["name", "height"]) + 
self.assertRaises(IllegalArgumentException, lambda: df1.join(df2, how="invalid-join-type")) + def test_conf(self): spark = self.spark spark.conf.set("bogo", "sipeo") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org