This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e821d84409f [SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments
e821d84409f is described below
commit e821d84409f00e03f9469c9e8e7040e9cc5a5d9f
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Jan 12 00:23:51 2023 +0900
[SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments
### What changes were proposed in this pull request?
This PR proposes to make Spark Connect's `SparkSession.range` accept floats, e.g., `spark.range(10e10)`.
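For illustration, a minimal usage sketch (not part of this patch), assuming a Spark Connect `SparkSession` is already bound to `spark`:

```python
# Float arguments are cast to int before building the Range plan,
# matching the behavior of non-Connect PySpark.
df1 = spark.range(10e10)                        # end given as a float
df2 = spark.range(1.0, 10.0, 2.0)               # start, end, step as floats
df3 = spark.range(0, 10, 1, numPartitions=2.0)  # numPartitions is also cast to int
print(df2.count())
```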
### Why are the changes needed?
For feature parity.
### Does this PR introduce _any_ user-facing change?
No to end users, since Spark Connect has not been released yet. Within Spark Connect, `SparkSession.range` now accepts floats.
### How was this patch tested?
Re-enabled previously skipped unit tests.
Closes #39499 from HyukjinKwon/SPARK-41977.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/session.py | 8 +++++++-
.../pyspark/sql/tests/connect/test_parity_dataframe.py | 4 ----
python/pyspark/sql/tests/test_dataframe.py | 17 +++++++++--------
3 files changed, 16 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4c5ea3da10e..618608d64f7 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -333,8 +333,14 @@ class SparkSession:
         else:
             actual_end = end
+        if numPartitions is not None:
+            numPartitions = int(numPartitions)
+
         return DataFrame.withPlan(
-            Range(start=start, end=actual_end, step=step, num_partitions=numPartitions), self
+            Range(
+                start=int(start), end=int(actual_end), step=int(step), num_partitions=numPartitions
+            ),
+            self,
         )
     range.__doc__ = PySparkSession.range.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 32fb6216ba9..5c3e4ee1a01 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -61,10 +61,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_extended_hint_types(self):
         super().test_extended_hint_types()
-    @unittest.skip("Spark Connect does not support JVM function _jdf but the tests depend on them")
-    def test_generic_hints(self):
-        super().test_generic_hints()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_help_command(self):
         super().test_help_command()
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 2a82b0ab90d..e83ecbf2e6e 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -23,6 +23,8 @@ import tempfile
import time
import unittest
from typing import cast
+import io
+from contextlib import redirect_stdout
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, lit, count, sum, mean, struct
@@ -534,20 +536,17 @@ class DataFrameTestsMixin:
         self.assertRaises(Exception, self.df.withColumns)
     def test_generic_hints(self):
-        from pyspark.sql import DataFrame
-
         df1 = self.spark.range(10e10).toDF("id")
         df2 = self.spark.range(10e10).toDF("id")
-        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast"), type(df1))
         # Dummy rules
-        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), type(df1))
-        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
-        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
+        with io.StringIO() as buf, redirect_stdout(buf):
+            df1.join(df2.hint("broadcast"), "id").explain(True)
+            self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin"))
     # add tests for SPARK-23647 (test more types for hint)
     def test_extended_hint_types(self):
@@ -556,6 +555,8 @@ class DataFrameTestsMixin:
         hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
         logical_plan = hinted_df._jdf.queryExecution().logical()
+        self.assertIsInstance(df.hint("broadcast", []), type(df))
+        self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
         self.assertEqual(1, logical_plan.toString().count("1.2345"))
         self.assertEqual(1, logical_plan.toString().count("what"))
         self.assertEqual(3, logical_plan.toString().count("itworks"))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]