This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ba39a6b4a60 [SPARK-41779][SPARK-41771][CONNECT][PYTHON] Make `__getitem__` support filter and select
ba39a6b4a60 is described below
commit ba39a6b4a60883708a2ed6e7e5c00a8649ddc66f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 30 16:11:02 2022 +0800
    [SPARK-41779][SPARK-41771][CONNECT][PYTHON] Make `__getitem__` support filter and select
### What changes were proposed in this pull request?
Make DataFrame `__getitem__` support three access patterns (a usage sketch follows the list):

1. filter: `cdf[cdf.a.isin(1, 2, 3)]`
2. select: `cdf[["col1", cdf.a]]`
3. index: `cdf[0]`
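A minimal usage sketch of the three forms, assuming a Spark Connect session `spark` and a hypothetical table with columns `a` and `col1` (names and data are illustrative, not from the patch):

    cdf = spark.sql("SELECT * FROM VALUES (1, 'x'), (2, 'y') AS tab(a, col1)")

    cdf[cdf.a.isin(1, 2, 3)]   # Column argument: row filter, same as cdf.filter(...)
    cdf[["col1", cdf.a]]       # list (or tuple): projection, same as cdf.select(...)
    cdf[0]                     # int: the first column, returned as a Column object
    cdf["a"]                   # str: column lookup by name (pre-existing behavior)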
### Why are the changes needed?
To be consistent with the existing [PySpark `DataFrame.__getitem__`](https://github.com/apache/spark/blob/master/python/pyspark/sql/dataframe.py#L2764-L2825).
### Does this PR introduce _any_ user-facing change?
Yes. `DataFrame.__getitem__` in Spark Connect now accepts Column, list/tuple, and int arguments, matching PySpark.
### How was this patch tested?
Added unit tests (`test_df_get_item` in `test_connect_basic.py`).
Closes #39300 from zhengruifeng/connect_df_getitem_filter.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/column.py | 2 -
python/pyspark/sql/connect/dataframe.py | 29 +++++++--
.../sql/tests/connect/test_connect_basic.py | 73 ++++++++++++++++++++++
3 files changed, 96 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 206d30b15d8..9be202145f2 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -458,8 +458,6 @@ def _test() -> None:
     # the row
     del pyspark.sql.connect.column.Column.isNotNull.__doc__
     del pyspark.sql.connect.column.Column.isNull.__doc__
-    del pyspark.sql.connect.column.Column.isin.__doc__
-    # TODO(SPARK-41771): __getitem__ does not work with Column.isin
     del pyspark.sql.connect.column.Column.getField.__doc__
     del pyspark.sql.connect.column.Column.getItem.__doc__
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 5b5a6c3f4b5..256e63122ab 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -901,13 +901,30 @@ class DataFrame:
     def __getattr__(self, name: str) -> "Column":
         return self[name]
 
-    def __getitem__(self, name: str) -> "Column":
-        # Check for alias
-        alias = self._get_alias()
-        if alias is not None:
-            return col(alias)
+    @overload
+    def __getitem__(self, item: Union[int, str]) -> Column:
+        ...
+
+    @overload
+    def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":
+        ...
+
+    def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
+        if isinstance(item, str):
+            # Check for alias
+            alias = self._get_alias()
+            if alias is not None:
+                return col(alias)
+            else:
+                return col(item)
+        elif isinstance(item, Column):
+            return self.filter(item)
+        elif isinstance(item, (list, tuple)):
+            return self.select(*item)
+        elif isinstance(item, int):
+            return col(self.columns[item])
         else:
-            return col(name)
+            raise TypeError("unexpected item type: %s" % type(item))
 
     def _print_plan(self) -> str:
         if self._plan:
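In short, the rewritten method dispatches purely on the argument's type. A rough equivalence sketch (`df`, `c`, `i` are hypothetical names, not code from this patch):

    df[c]           # c is a Column  -> df.filter(c), returns a DataFrame
    df[[c, "x"]]    # list or tuple  -> df.select(c, "x"), returns a DataFrame
    df[i]           # i is an int    -> col(df.columns[i]), returns a Column
    df["x"]         # str            -> col("x"), or col(alias) if the frame is aliased
    df[1.5]         # anything else  -> TypeError("unexpected item type: ...")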
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 99ee54a87fa..9663f3123f9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -117,6 +117,79 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
+    def test_df_get_item(self):
+        # SPARK-41779: test __getitem__
+
+        query = """
+            SELECT * FROM VALUES
+            (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+            AS tab(a, b, c)
+            """
+
+        # +-----+----+----+
+        # |    a|   b|   c|
+        # +-----+----+----+
+        # | true|   1|null|
+        # |false|null| 2.0|
+        # | null|   3| 3.0|
+        # +-----+----+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # filter
+        self.assert_eq(
+            cdf[cdf.a].toPandas(),
+            sdf[sdf.a].toPandas(),
+        )
+        self.assert_eq(
+            cdf[cdf.b.isin(2, 3)].toPandas(),
+            sdf[sdf.b.isin(2, 3)].toPandas(),
+        )
+        self.assert_eq(
+            cdf[cdf.c > 1.5].toPandas(),
+            sdf[sdf.c > 1.5].toPandas(),
+        )
+
+        # select
+        self.assert_eq(
+            cdf[[cdf.a, "b", cdf.c]].toPandas(),
+            sdf[[sdf.a, "b", sdf.c]].toPandas(),
+        )
+        self.assert_eq(
+            cdf[(cdf.a, "b", cdf.c)].toPandas(),
+            sdf[(sdf.a, "b", sdf.c)].toPandas(),
+        )
+
+        # select by index
+        self.assertTrue(isinstance(cdf[0], Column))
+        self.assertTrue(isinstance(cdf[1], Column))
+        self.assertTrue(isinstance(cdf[2], Column))
+
+        self.assert_eq(
+            cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
+            sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
+        )
+
+        # check error
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[1.5]
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[None]
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[cdf]
+
     def test_error_handling(self):
         # SPARK-41533 Proper error handling for Spark Connect
         df = self.connect.range(10).select("id2")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]