This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e821d84409f [SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments
e821d84409f is described below

commit e821d84409f00e03f9469c9e8e7040e9cc5a5d9f
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Jan 12 00:23:51 2023 +0900

    [SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments

    ### What changes were proposed in this pull request?

    This PR proposes to make Spark Connect's `SparkSession.range` accept floats, e.g., `spark.range(10e10)`.

    ### Why are the changes needed?

    For feature parity.

    ### Does this PR introduce _any_ user-facing change?

    No to end users, since Spark Connect has not been released yet. `SparkSession.range` now allows floats.

    ### How was this patch tested?

    Previously skipped unit tests were re-enabled.

    Closes #39499 from HyukjinKwon/SPARK-41977.

    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/session.py                  |  8 +++++++-
 .../pyspark/sql/tests/connect/test_parity_dataframe.py |  4 ----
 python/pyspark/sql/tests/test_dataframe.py             | 17 +++++++++--------
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4c5ea3da10e..618608d64f7 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -333,8 +333,14 @@ class SparkSession:
         else:
             actual_end = end
 
+        if numPartitions is not None:
+            numPartitions = int(numPartitions)
+
         return DataFrame.withPlan(
-            Range(start=start, end=actual_end, step=step, num_partitions=numPartitions), self
+            Range(
+                start=int(start), end=int(actual_end), step=int(step), num_partitions=numPartitions
+            ),
+            self,
         )
 
     range.__doc__ = PySparkSession.range.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 32fb6216ba9..5c3e4ee1a01 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -61,10 +61,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_extended_hint_types(self):
         super().test_extended_hint_types()
 
-    @unittest.skip("Spark Connect does not support JVM function _jdf but the tests depend on them")
-    def test_generic_hints(self):
-        super().test_generic_hints()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_help_command(self):
         super().test_help_command()
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 2a82b0ab90d..e83ecbf2e6e 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -23,6 +23,8 @@ import tempfile
 import time
 import unittest
 from typing import cast
+import io
+from contextlib import redirect_stdout
 
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
@@ -534,20 +536,17 @@ class DataFrameTestsMixin:
         self.assertRaises(Exception, self.df.withColumns)
 
     def test_generic_hints(self):
-        from pyspark.sql import DataFrame
-
         df1 = self.spark.range(10e10).toDF("id")
         df2 = self.spark.range(10e10).toDF("id")
 
-        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast"), type(df1))
 
         # Dummy rules
-        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), type(df1))
 
-        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
-        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
+        with io.StringIO() as buf, redirect_stdout(buf):
+            df1.join(df2.hint("broadcast"), "id").explain(True)
+            self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin"))
 
     # add tests for SPARK-23647 (test more types for hint)
     def test_extended_hint_types(self):
@@ -556,6 +555,8 @@ class DataFrameTestsMixin:
         hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
         logical_plan = hinted_df._jdf.queryExecution().logical()
 
+        self.assertIsInstance(df.hint("broadcast", []), type(df))
+        self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
         self.assertEqual(1, logical_plan.toString().count("1.2345"))
         self.assertEqual(1, logical_plan.toString().count("what"))
         self.assertEqual(3, logical_plan.toString().count("itworks"))
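For context, a minimal usage sketch of the behavior this commit enables. It is not part of the commit: the builder-based session setup and the `sc://localhost` endpoint are illustrative placeholders for any reachable Spark Connect server.

    from pyspark.sql import SparkSession

    # Illustrative endpoint; substitute your own Spark Connect server.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # With this change, float arguments are coerced via int() before the
    # Range plan is built, so these now work on Spark Connect too:
    df1 = spark.range(10e10)              # end given as a float
    df2 = spark.range(0, 10e2, 2.0, 4.0)  # start, end, step, numPartitions

    assert df2.count() == 500  # range(0, 1000, 2) yields 500 rows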
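The rewritten `test_generic_hints` also shows a Connect-friendly way to inspect query plans: capture the text that `DataFrame.explain(True)` prints to stdout instead of reaching into the JVM through the `_jdf` handle, which Spark Connect does not expose. A standalone sketch of that pattern (the `explain_string` helper name is ours, not from the commit):

    import io
    from contextlib import redirect_stdout

    def explain_string(df):
        # explain() prints to stdout; redirect it into an in-memory buffer
        # and hand the captured text back to the caller.
        with io.StringIO() as buf, redirect_stdout(buf):
            df.explain(True)
            return buf.getvalue()

    # e.g. verify a broadcast hint took effect:
    # plan = explain_string(df1.join(df2.hint("broadcast"), "id"))
    # assert plan.count("BroadcastHashJoin") == 1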