This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 2a361b9ddfa [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client
2a361b9ddfa is described below

commit 2a361b9ddfa766c719399b35c38f4dafe68353ee
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Tue Nov 8 08:30:49 2022 +0800

    [SPARK-41002][CONNECT][PYTHON] Compatible `take`, `head` and `first` API in Python client

    ### What changes were proposed in this pull request?

    1. Add `take(n)` API.
    2. Change `head(n)` API to return `Union[Optional[Row], List[Row]]`.
    3. Update `first()` to return `Optional[Row]`.

    ### Why are the changes needed?

    Improve API coverage.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    UT

    Closes #38488 from amaliujia/SPARK-41002.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            | 61 ++++++++++++++++++++--
 .../sql/tests/connect/test_connect_basic.py        | 36 +++++++++++--
 2 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9ba4b99ba0..9eecdbb7145 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@ from typing import (
     Tuple,
     Union,
     TYPE_CHECKING,
+    overload,
 )
 
 import pandas
@@ -211,14 +212,66 @@ class DataFrame(object):
             plan.Filter(child=self._plan, filter=condition), session=self._session
         )
 
-    def first(self) -> Optional["pandas.DataFrame"]:
-        return self.head(1)
+    def first(self) -> Optional[Row]:
+        """Returns the first row as a :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`Row`
+            First row if :class:`DataFrame` is not empty, otherwise ``None``.
+        """
+        return self.head()
 
     def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
         return GroupingFrame(self, *cols)
 
-    def head(self, n: int) -> Optional["pandas.DataFrame"]:
-        return self.limit(n).toPandas()
+    @overload
+    def head(self) -> Optional[Row]:
+        ...
+
+    @overload
+    def head(self, n: int) -> List[Row]:
+        ...
+
+    def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
+        """Returns the first ``n`` rows.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int, optional
+            default 1. Number of rows to return.
+
+        Returns
+        -------
+        If n is greater than 1, return a list of :class:`Row`.
+        If n is 1, return a single Row.
+        """
+        if n is None:
+            rs = self.head(1)
+            return rs[0] if rs else None
+        return self.take(n)
+
+    def take(self, num: int) -> List[Row]:
+        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        num : int
+            Number of records to return. Will return this number of records
+            or whatever number is available.
+
+        Returns
+        -------
+        list
+            List of rows
+        """
+        return self.limit(num).collect()
 
     # TODO: extend `on` to also be type List[ColumnRef].
     def join(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 18a752ee19d..a0f046907f7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -46,6 +46,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
     if have_pandas:
         connect: RemoteSparkSession
         tbl_name: str
+        tbl_name_empty: str
         df_text: "DataFrame"
 
     @classmethod
@@ -61,6 +62,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
         cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
 
         cls.tbl_name = "test_connect_basic_table_1"
+        cls.tbl_name_empty = "test_connect_basic_table_empty"
 
         # Cleanup test data
         cls.spark_connect_clean_up_test_data()
@@ -80,10 +82,21 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
         # Since we might create multiple Spark sessions, we need to create global temporary view
         # that is specifically maintained in the "global_temp" schema.
         df.write.saveAsTable(cls.tbl_name)
+        empty_table_schema = StructType(
+            [
+                StructField("firstname", StringType(), True),
+                StructField("middlename", StringType(), True),
+                StructField("lastname", StringType(), True),
+            ]
+        )
+        emptyRDD = cls.spark.sparkContext.emptyRDD()
+        empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
+        empty_df.write.saveAsTable(cls.tbl_name_empty)
 
     @classmethod
     def spark_connect_clean_up_test_data(cls: Any) -> None:
         cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
+        cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
@@ -145,10 +158,27 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         self.assertEqual(1, len(pdf.index))
 
     def test_head(self):
+        # SPARK-41002: test `head` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.head()))
+        self.assertIsNotNone(len(df.head(1)))
+        self.assertIsNotNone(len(df.head(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.head())
+
+    def test_first(self):
+        # SPARK-41002: test `first` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.first()))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.first())
+
+    def test_take(self) -> None:
+        # SPARK-41002: test `take` API in Python Client
         df = self.connect.read.table(self.tbl_name)
-        pd = df.head(10)
-        self.assertIsNotNone(pd)
-        self.assertEqual(10, len(pd.index))
+        self.assertEqual(5, len(df.take(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertEqual(0, len(df2.take(5)))
 
     def test_range(self):
         self.assertTrue(

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
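
---------------------------------------------------------------------
Note: a minimal usage sketch of the take/head/first semantics this patch
gives the Connect client, written against the classic pyspark.sql API that
the patch mirrors. The session setup, table contents, and variable names
below are illustrative assumptions, not part of the patch; any local
Spark 3.x installation should run it as-is.

    # Sketch: the shared take/head/first behavior. All names are illustrative.
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StringType, StructField, StructType

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    df = spark.createDataFrame([("a",), ("b",), ("c",)], ["letter"])
    empty_df = spark.createDataFrame(
        [], StructType([StructField("letter", StringType(), True)])
    )

    assert df.first() == df.head()    # head() with no argument: a single Row
    assert df.head(2) == df.take(2)   # head(n) delegates to take(n): List[Row]
    assert len(df.take(5)) == 3       # take(n) returns whatever is available
    assert empty_df.first() is None   # empty DataFrame: None, not an error
    assert empty_df.take(5) == []     # ... and take(n) gives an empty list

    spark.stop()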