This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c31ba4742e [SPARK-41907][CONNECT][PYTHON][TESTS] Update and enable 
test `test_sampleby`
5c31ba4742e is described below

commit 5c31ba4742e98d6a33ff36ee82dcc9a605c86538
Author: Jiaan Geng <[email protected]>
AuthorDate: Tue Jan 10 19:49:34 2023 +0800

    [SPARK-41907][CONNECT][PYTHON][TESTS] Update and enable test `test_sampleby`
    
    ### What changes were proposed in this pull request?
    The `test_functions.py` have one test case for `stat.sampleBy`.
    ```
    df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in range(100)])
    sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
    self.assertTrue(sampled.count() == 35)
    ```
    Connect's py API cannot passed the tests.
    ```
    Traceback (most recent call last):
      File 
"/Users/s.singh/personal/spark-oss/python/pyspark/sql/tests/test_functions.py", 
line 202, in test_sampleby
        self.assertTrue(sampled.count() == 35)
    AssertionError: False is not true
    ```
    
    After my investigation, the root cause is the plan is different from 
pyspark, so the result is not determined.
    The plan come from pyspark show below.
    ```
    == Physical Plan ==
    * Filter (2)
    +- * Scan ExistingRDD (1)
    
    (1) Scan ExistingRDD [codegen id : 1]
    Output [2]: [a#4L, b#5L]
    Arguments: [a#4L, b#5L], MapPartitionsRDD[9] at applySchemaToPythonRDD at 
NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
    
    (2) Filter [codegen id : 1]
    Input [2]: [a#4L, b#5L]
    Condition : UDF(b#5L, rand(0))
    ```
    
    The plan come from connect show below.
    ```
    == Physical Plan ==
    LocalTableScan (1)
    
    (1) LocalTableScan
    Output [2]: [a#5L, b#6L]
    Arguments: [a#5L, b#6L]
    ```
    
    ### Why are the changes needed?
    The issue is not related to `stat.sampleBy` directly.
    This PR just let the code follows pyspark API and update the comment about 
skip test.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    N/A
    
    Closes #39476 from beliefer/SPARK-41907.
    
    Authored-by: Jiaan Geng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py                   | 7 +++++--
 python/pyspark/sql/tests/connect/test_parity_functions.py | 5 -----
 python/pyspark/sql/tests/test_functions.py                | 2 +-
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 5ff3d59ddd6..6e4d9c5a2db 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1152,7 +1152,11 @@ class DataFrame:
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: 
Optional[int] = None
     ) -> "DataFrame":
-        if not isinstance(col, (Column, str)):
+        from pyspark.sql.connect.expressions import ColumnReference
+
+        if isinstance(col, str):
+            col = Column(ColumnReference(name=col))
+        elif not isinstance(col, Column):
             raise TypeError("col must be a string or a column, but got %r" % 
type(col))
         if not isinstance(fractions, dict):
             raise TypeError("fractions must be a dict but got %r" % 
type(fractions))
@@ -1161,7 +1165,6 @@ class DataFrame:
                 raise TypeError("key must be float, int, or string, but got 
%r" % type(k))
             fractions[k] = float(v)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-
         return DataFrame.withPlan(
             plan.StatSampleBy(child=self._plan, col=col, fractions=fractions, 
seed=seed),
             session=self._session,
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py 
b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 65e1eb31fee..2f6ed05559f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -110,11 +110,6 @@ class FunctionsParityTests(FunctionsTestsMixin, 
ReusedConnectTestCase):
     def test_sorting_functions_with_column(self):
         super().test_sorting_functions_with_column()
 
-    # TODO(SPARK-41907): sampleby returning wrong output
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_sampleby(self):
-        super().test_sampleby()
-
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_functions.py 
b/python/pyspark/sql/tests/test_functions.py
index 0fe13279200..4db1eed1eb1 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -200,7 +200,7 @@ class FunctionsTestsMixin:
     def test_sampleby(self):
         df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in 
range(100)])
         sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
-        self.assertTrue(sampled.count() == 35)
+        self.assertTrue(35 <= sampled.count() <= 36)
 
     def test_cov(self):
         df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to