This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e821d84409f [SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments
e821d84409f is described below

commit e821d84409f00e03f9469c9e8e7040e9cc5a5d9f
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Thu Jan 12 00:23:51 2023 +0900

    [SPARK-41977][SPARK-41978][CONNECT] SparkSession.range to take float as arguments
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to make Spark Connect's `SparkSession.range` accept floats, e.g., `spark.range(10e10)`.
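
    A minimal usage sketch (the `sc://localhost:15002` endpoint and builder setup are illustrative assumptions, not part of this change), showing floats flowing through the Connect `range` API:

    ```python
    from pyspark.sql.connect.session import SparkSession

    # Hypothetical Connect endpoint; adjust to your environment.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    # 10e10 is a float; the patched range() coerces start/end/step (and
    # numPartitions, when given) with int() before building the Range plan.
    df = spark.range(10e10)
    df2 = spark.range(0, 10e2, numPartitions=4.0)
    ```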
    
    ### Why are the changes needed?
    
    For feature parity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No to end users, since Spark Connect has not been released yet. `SparkSession.range` now accepts floats.
    
    ### How was this patch tested?
    
    Re-enabled the unit tests that had previously been skipped.
    
    Closes #39499 from HyukjinKwon/SPARK-41977.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/session.py                   |  8 +++++++-
 .../pyspark/sql/tests/connect/test_parity_dataframe.py  |  4 ----
 python/pyspark/sql/tests/test_dataframe.py              | 17 +++++++++--------
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4c5ea3da10e..618608d64f7 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -333,8 +333,14 @@ class SparkSession:
         else:
             actual_end = end
 
+        if numPartitions is not None:
+            numPartitions = int(numPartitions)
+
         return DataFrame.withPlan(
-            Range(start=start, end=actual_end, step=step, num_partitions=numPartitions), self
+            Range(
+                start=int(start), end=int(actual_end), step=int(step), num_partitions=numPartitions
+            ),
+            self,
         )
 
     range.__doc__ = PySparkSession.range.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 32fb6216ba9..5c3e4ee1a01 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -61,10 +61,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_extended_hint_types(self):
         super().test_extended_hint_types()
 
-    @unittest.skip("Spark Connect does not support JVM function _jdf but the 
tests depend on them")
-    def test_generic_hints(self):
-        super().test_generic_hints()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on 
them.")
     def test_help_command(self):
         super().test_help_command()
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 2a82b0ab90d..e83ecbf2e6e 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -23,6 +23,8 @@ import tempfile
 import time
 import unittest
 from typing import cast
+import io
+from contextlib import redirect_stdout
 
 from pyspark.sql import SparkSession, Row
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
@@ -534,20 +536,17 @@ class DataFrameTestsMixin:
         self.assertRaises(Exception, self.df.withColumns)
 
     def test_generic_hints(self):
-        from pyspark.sql import DataFrame
-
         df1 = self.spark.range(10e10).toDF("id")
         df2 = self.spark.range(10e10).toDF("id")
 
-        self.assertIsInstance(df1.hint("broadcast"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast"), type(df1))
 
         # Dummy rules
-        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
-        self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
+        self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), type(df1))
 
-        plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
-        self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
+        with io.StringIO() as buf, redirect_stdout(buf):
+            df1.join(df2.hint("broadcast"), "id").explain(True)
+            self.assertEqual(1, buf.getvalue().count("BroadcastHashJoin"))
 
     # add tests for SPARK-23647 (test more types for hint)
     def test_extended_hint_types(self):
@@ -556,6 +555,8 @@ class DataFrameTestsMixin:
         hinted_df = df.hint("my awesome hint", 1.2345, "what", such_a_nice_list)
         logical_plan = hinted_df._jdf.queryExecution().logical()
 
+        self.assertIsInstance(df.hint("broadcast", []), type(df))
+        self.assertIsInstance(df.hint("broadcast", ["foo", "bar"]), type(df))
         self.assertEqual(1, logical_plan.toString().count("1.2345"))
         self.assertEqual(1, logical_plan.toString().count("what"))
         self.assertEqual(3, logical_plan.toString().count("itworks"))
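
A standalone sketch of the test pattern adopted above: instead of reaching into the JVM-only `_jdf` handle, the rewritten `test_generic_hints` captures the output of `DataFrame.explain(True)` with `redirect_stdout` and asserts on the plan text. The `local[4]` master and the tiny ranges below are illustrative assumptions, not from the patch:

```python
import io
from contextlib import redirect_stdout

from pyspark.sql import SparkSession

# Illustrative local session; the pattern itself avoids any JVM-specific APIs.
spark = SparkSession.builder.master("local[4]").getOrCreate()

df1 = spark.range(100).toDF("id")
df2 = spark.range(100).toDF("id")

# explain() prints to stdout rather than returning a string, so capture it.
with io.StringIO() as buf, redirect_stdout(buf):
    df1.join(df2.hint("broadcast"), "id").explain(True)
    plan_text = buf.getvalue()

# The broadcast hint should yield exactly one BroadcastHashJoin in the physical plan.
assert plan_text.count("BroadcastHashJoin") == 1
```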


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
