This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ba39a6b4a60 [SPARK-41779][SPARK-41771][CONNECT][PYTHON] Make `__getitem__` support filter and select
ba39a6b4a60 is described below

commit ba39a6b4a60883708a2ed6e7e5c00a8649ddc66f
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Dec 30 16:11:02 2022 +0800

    [SPARK-41779][SPARK-41771][CONNECT][PYTHON] Make `__getitem__` support filter and select
    
    ### What changes were proposed in this pull request?
    Make DataFrame `__getitem__` support:
    1. filter: `cdf[cdf.a.isin(1, 2, 3)]`
    2. select: `cdf[["col1", cdf.a]]`
    3. index: `cdf[0]`
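    
    For reference, a minimal sketch of the same three forms in plain (non-Connect) PySpark, whose existing behavior this change mirrors; the session setup, data, and column names below are illustrative only and not part of this patch:
    
        from pyspark.sql import SparkSession
    
        spark = SparkSession.builder.getOrCreate()
        df = spark.sql("SELECT * FROM VALUES (1, 'x'), (2, 'y'), (3, 'z') AS tab(a, col1)")
    
        df[df.a.isin(1, 2, 3)].show()   # filter: keep rows where `a` is 1, 2, or 3
        df[["col1", df.a]].show()       # select: project the listed columns
        df[0]                           # index: the Column for the first column, i.e. df["a"]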
    
    ### Why are the changes needed?
    To be consistent with [PySpark](https://github.com/apache/spark/blob/master/python/pyspark/sql/dataframe.py#L2764-L2825).
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ### How was this patch tested?
    Added unit tests.
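    
    One way to run just this suite locally (assuming a built Spark checkout with the PySpark test dependencies installed) is:
    
        python/run-tests --testnames 'pyspark.sql.tests.connect.test_connect_basic'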
    
    Closes #39300 from zhengruifeng/connect_df_getitem_filter.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/column.py               |  2 -
 python/pyspark/sql/connect/dataframe.py            | 29 +++++++--
 .../sql/tests/connect/test_connect_basic.py        | 73 ++++++++++++++++++++++
 3 files changed, 96 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 206d30b15d8..9be202145f2 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -458,8 +458,6 @@ def _test() -> None:
         #  the row
         del pyspark.sql.connect.column.Column.isNotNull.__doc__
         del pyspark.sql.connect.column.Column.isNull.__doc__
-        del pyspark.sql.connect.column.Column.isin.__doc__
-        # TODO(SPARK-41771): __getitem__ does not work with Column.isin
         del pyspark.sql.connect.column.Column.getField.__doc__
         del pyspark.sql.connect.column.Column.getItem.__doc__
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 5b5a6c3f4b5..256e63122ab 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -901,13 +901,30 @@ class DataFrame:
     def __getattr__(self, name: str) -> "Column":
         return self[name]
 
-    def __getitem__(self, name: str) -> "Column":
-        # Check for alias
-        alias = self._get_alias()
-        if alias is not None:
-            return col(alias)
+    @overload
+    def __getitem__(self, item: Union[int, str]) -> Column:
+        ...
+
+    @overload
+    def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":
+        ...
+
+    def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
+        if isinstance(item, str):
+            # Check for alias
+            alias = self._get_alias()
+            if alias is not None:
+                return col(alias)
+            else:
+                return col(item)
+        elif isinstance(item, Column):
+            return self.filter(item)
+        elif isinstance(item, (list, tuple)):
+            return self.select(*item)
+        elif isinstance(item, int):
+            return col(self.columns[item])
         else:
-            return col(name)
+            raise TypeError("unexpected item type: %s" % type(item))
 
     def _print_plan(self) -> str:
         if self._plan:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 99ee54a87fa..9663f3123f9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -117,6 +117,79 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
+    def test_df_get_item(self):
+        # SPARK-41779: test __getitem__
+
+        query = """
+            SELECT * FROM VALUES
+            (true, 1, NULL), (false, NULL, 2.0), (NULL, 3, 3.0)
+            AS tab(a, b, c)
+            """
+
+        # +-----+----+----+
+        # |    a|   b|   c|
+        # +-----+----+----+
+        # | true|   1|null|
+        # |false|null| 2.0|
+        # | null|   3| 3.0|
+        # +-----+----+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        # filter
+        self.assert_eq(
+            cdf[cdf.a].toPandas(),
+            sdf[sdf.a].toPandas(),
+        )
+        self.assert_eq(
+            cdf[cdf.b.isin(2, 3)].toPandas(),
+            sdf[sdf.b.isin(2, 3)].toPandas(),
+        )
+        self.assert_eq(
+            cdf[cdf.c > 1.5].toPandas(),
+            sdf[sdf.c > 1.5].toPandas(),
+        )
+
+        # select
+        self.assert_eq(
+            cdf[[cdf.a, "b", cdf.c]].toPandas(),
+            sdf[[sdf.a, "b", sdf.c]].toPandas(),
+        )
+        self.assert_eq(
+            cdf[(cdf.a, "b", cdf.c)].toPandas(),
+            sdf[(sdf.a, "b", sdf.c)].toPandas(),
+        )
+
+        # select by index
+        self.assertTrue(isinstance(cdf[0], Column))
+        self.assertTrue(isinstance(cdf[1], Column))
+        self.assertTrue(isinstance(cdf[2], Column))
+
+        self.assert_eq(
+            cdf[[cdf[0], cdf[1], cdf[2]]].toPandas(),
+            sdf[[sdf[0], sdf[1], sdf[2]]].toPandas(),
+        )
+
+        # check error
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[1.5]
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[None]
+
+        with self.assertRaisesRegex(
+            TypeError,
+            "unexpected item type",
+        ):
+            cdf[cdf]
+
     def test_error_handling(self):
         # SPARK-41533 Proper error handling for Spark Connect
         df = self.connect.range(10).select("id2")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
