This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5c31ba4742e [SPARK-41907][CONNECT][PYTHON][TESTS] Update and enable
test `test_sampleby`
5c31ba4742e is described below
commit 5c31ba4742e98d6a33ff36ee82dcc9a605c86538
Author: Jiaan Geng <[email protected]>
AuthorDate: Tue Jan 10 19:49:34 2023 +0800
[SPARK-41907][CONNECT][PYTHON][TESTS] Update and enable test `test_sampleby`
### What changes were proposed in this pull request?
The `test_functions.py` have one test case for `stat.sampleBy`.
```
df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in range(100)])
sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
self.assertTrue(sampled.count() == 35)
```
Connect's py API cannot passed the tests.
```
Traceback (most recent call last):
File
"/Users/s.singh/personal/spark-oss/python/pyspark/sql/tests/test_functions.py",
line 202, in test_sampleby
self.assertTrue(sampled.count() == 35)
AssertionError: False is not true
```
After my investigation, the root cause is the plan is different from
pyspark, so the result is not determined.
The plan come from pyspark show below.
```
== Physical Plan ==
* Filter (2)
+- * Scan ExistingRDD (1)
(1) Scan ExistingRDD [codegen id : 1]
Output [2]: [a#4L, b#5L]
Arguments: [a#4L, b#5L], MapPartitionsRDD[9] at applySchemaToPythonRDD at
NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)
(2) Filter [codegen id : 1]
Input [2]: [a#4L, b#5L]
Condition : UDF(b#5L, rand(0))
```
The plan come from connect show below.
```
== Physical Plan ==
LocalTableScan (1)
(1) LocalTableScan
Output [2]: [a#5L, b#6L]
Arguments: [a#5L, b#6L]
```
### Why are the changes needed?
The issue is not related to `stat.sampleBy` directly.
This PR just let the code follows pyspark API and update the comment about
skip test.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
N/A
Closes #39476 from beliefer/SPARK-41907.
Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 7 +++++--
python/pyspark/sql/tests/connect/test_parity_functions.py | 5 -----
python/pyspark/sql/tests/test_functions.py | 2 +-
3 files changed, 6 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index 5ff3d59ddd6..6e4d9c5a2db 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1152,7 +1152,11 @@ class DataFrame:
def sampleBy(
self, col: "ColumnOrName", fractions: Dict[Any, float], seed:
Optional[int] = None
) -> "DataFrame":
- if not isinstance(col, (Column, str)):
+ from pyspark.sql.connect.expressions import ColumnReference
+
+ if isinstance(col, str):
+ col = Column(ColumnReference(name=col))
+ elif not isinstance(col, Column):
raise TypeError("col must be a string or a column, but got %r" %
type(col))
if not isinstance(fractions, dict):
raise TypeError("fractions must be a dict but got %r" %
type(fractions))
@@ -1161,7 +1165,6 @@ class DataFrame:
raise TypeError("key must be float, int, or string, but got
%r" % type(k))
fractions[k] = float(v)
seed = seed if seed is not None else random.randint(0, sys.maxsize)
-
return DataFrame.withPlan(
plan.StatSampleBy(child=self._plan, col=col, fractions=fractions,
seed=seed),
session=self._session,
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py
b/python/pyspark/sql/tests/connect/test_parity_functions.py
index 65e1eb31fee..2f6ed05559f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -110,11 +110,6 @@ class FunctionsParityTests(FunctionsTestsMixin,
ReusedConnectTestCase):
def test_sorting_functions_with_column(self):
super().test_sorting_functions_with_column()
- # TODO(SPARK-41907): sampleby returning wrong output
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_sampleby(self):
- super().test_sampleby()
-
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/test_functions.py
b/python/pyspark/sql/tests/test_functions.py
index 0fe13279200..4db1eed1eb1 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -200,7 +200,7 @@ class FunctionsTestsMixin:
def test_sampleby(self):
df = self.spark.createDataFrame([Row(a=i, b=(i % 3)) for i in
range(100)])
sampled = df.stat.sampleBy("b", fractions={0: 0.5, 1: 0.5}, seed=0)
- self.assertTrue(sampled.count() == 35)
+ self.assertTrue(35 <= sampled.count() <= 36)
def test_cov(self):
df = self.spark.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]