This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 764cb3e07364 [SPARK-46750][CONNECT][PYTHON] DataFrame APIs code clean up
764cb3e07364 is described below

commit 764cb3e073644b9d543502d8951c47e41ba0f46b
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Thu Jan 18 08:17:29 2024 +0800

    [SPARK-46750][CONNECT][PYTHON] DataFrame APIs code clean up

    ### What changes were proposed in this pull request?
    1. Unify the imports.
    2. Delete unused helper functions and variables.

    ### Why are the changes needed?
    Code cleanup.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #44771 from zhengruifeng/py_df_cleanup.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py | 38 ++++++++++-----------------------
 python/pyspark/sql/connect/group.py     | 15 ++++---------
 2 files changed, 15 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 7ee27065208c..0cf6c0921f78 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -72,19 +72,12 @@ from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import (
+    SortOrder,
     ColumnReference,
     UnresolvedRegex,
     UnresolvedStar,
 )
-from pyspark.sql.connect.functions.builtin import (
-    _to_col,
-    _invoke_function,
-    col,
-    lit,
-    udf,
-    struct,
-    expr as sql_expression,
-)
+from pyspark.sql.connect.functions import builtin as F
 from pyspark.sql.pandas.types import from_arrow_schema
@@ -199,9 +192,9 @@ class DataFrame:
             expr = expr[0]  # type: ignore[assignment]
         for element in expr:
             if isinstance(element, str):
-                sql_expr.append(sql_expression(element))
+                sql_expr.append(F.expr(element))
             else:
-                sql_expr.extend([sql_expression(e) for e in element])
+                sql_expr.extend([F.expr(e) for e in element])

         return DataFrame(plan.Project(self._plan, *sql_expr), session=self._session)
@@ -215,7 +208,7 @@ class DataFrame:
             )

         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
+            measures = [F._invoke_function(f, F.col(e)) for e, f in exprs[0].items()]
             return self.groupBy().agg(*measures)
         else:
             # other expressions
@@ -259,7 +252,7 @@ class DataFrame:
     sparkSession.__doc__ = PySparkDataFrame.sparkSession.__doc__

     def count(self) -> int:
-        table, _ = self.agg(_invoke_function("count", lit(1)))._to_table()
+        table, _ = self.agg(F._invoke_function("count", F.lit(1)))._to_table()
         return table[0][0].as_py()

     count.__doc__ = PySparkDataFrame.count.__doc__
@@ -352,8 +345,6 @@ class DataFrame:
         self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
     ) -> "DataFrame":
         def _convert_col(col: "ColumnOrName") -> "ColumnOrName":
-            from pyspark.sql.connect.expressions import SortOrder, ColumnReference
-
             if isinstance(col, Column):
                 if isinstance(col._expr, SortOrder):
                     return col
@@ -471,7 +462,7 @@ class DataFrame:

     def filter(self, condition: Union[Column, str]) -> "DataFrame":
         if isinstance(condition, str):
-            expr = sql_expression(condition)
+            expr = F.expr(condition)
         else:
             expr = condition
         return DataFrame(plan.Filter(child=self._plan, filter=expr), session=self._session)
@@ -713,7 +704,7 @@ class DataFrame:
                 )
             else:
                 _c = c  # type: ignore[assignment]
-            _cols.append(_to_col(cast("ColumnOrName", _c)))
+            _cols.append(F._to_col(cast("ColumnOrName", _c)))

         ascending = kwargs.get("ascending", True)
         if isinstance(ascending, (bool, int)):
@@ -1652,8 +1643,6 @@ class DataFrame:
     def sampleBy(
         self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
     ) -> "DataFrame":
-        from pyspark.sql.connect.expressions import ColumnReference
-
         if isinstance(col, str):
             col = Column(ColumnReference(col))
         elif not isinstance(col, Column):
@@ -1754,7 +1743,7 @@ class DataFrame:
         elif isinstance(item, (list, tuple)):
             return self.select(*item)
         elif isinstance(item, int):
-            return col(self.columns[item])
+            return F.col(self.columns[item])
         else:
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
@@ -1768,11 +1757,6 @@ class DataFrame:

     __dir__.__doc__ = PySparkDataFrame.__dir__.__doc__

-    def _print_plan(self) -> str:
-        if self._plan:
-            return self._plan.print()
-        return ""
-
     def collect(self) -> List[Row]:
         table, schema = self._to_table()
@@ -2084,8 +2068,8 @@ class DataFrame:
         def foreach_func(row: Any) -> None:
             f(row)

-        self.select(struct(*self.schema.fieldNames()).alias("row")).select(
-            udf(foreach_func, StructType())("row")  # type: ignore[arg-type]
+        self.select(F.struct(*self.schema.fieldNames()).alias("row")).select(
+            F.udf(foreach_func, StructType())("row")  # type: ignore[arg-type]
         ).collect()

     foreach.__doc__ = PySparkDataFrame.foreach.__doc__
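For readers skimming the dataframe.py hunks: the patch replaces many per-symbol
imports from pyspark.sql.connect.functions.builtin with a single module alias F.
Below is a minimal sketch of the same convention written against the public
PySpark API rather than the internal module; it is illustrative only, not part
of the patch, and assumes a local pyspark installation with a working session:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F  # one alias instead of many symbol imports

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(5).withColumn("flag", F.lit(1))        # F.lit instead of a bare lit
    df = df.filter(F.expr("id % 2 = 0")).select(F.col("id"), "flag")
    df.show()

The aliased style keeps call sites self-describing (F.expr, F.col, F.lit) and
avoids re-listing every helper in the import block, which is what the hunks
above mechanically apply.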
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 2ccd7463b9e0..db4c9f57c5c2 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -40,7 +40,7 @@ from pyspark.sql.types import StructType

 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.functions.builtin import _invoke_function, col, lit
+from pyspark.sql.connect.functions import builtin as F
 from pyspark.errors import PySparkNotImplementedError, PySparkTypeError

 if TYPE_CHECKING:
@@ -132,7 +132,7 @@ class GroupedData:
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             # Convert the dict into key value pairs
-            aggregate_cols = [_invoke_function(exprs[0][k], col(k)) for k in exprs[0]]
+            aggregate_cols = [F._invoke_function(exprs[0][k], F.col(k)) for k in exprs[0]]
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -166,8 +166,6 @@ class GroupedData:
             field.name for field in schema.fields if isinstance(field.dataType, NumericType)
         ]

-        agg_cols: List[str] = []
-
         if len(cols) > 0:
             invalid_cols = [c for c in cols if c not in numerical_cols]
             if len(invalid_cols) > 0:
@@ -185,7 +183,7 @@ class GroupedData:
                 child=self._df._plan,
                 group_type=self._group_type,
                 grouping_cols=self._grouping_cols,
-                aggregate_cols=[_invoke_function(function, col(c)) for c in agg_cols],
+                aggregate_cols=[F._invoke_function(function, F.col(c)) for c in agg_cols],
                 pivot_col=self._pivot_col,
                 pivot_values=self._pivot_values,
                 grouping_sets=self._grouping_sets,
@@ -216,7 +214,7 @@ class GroupedData:
     mean = avg

     def count(self) -> "DataFrame":
-        return self.agg(_invoke_function("count", lit(1)).alias("count"))
+        return self.agg(F._invoke_function("count", F.lit(1)).alias("count"))

     count.__doc__ = PySparkGroupedData.count.__doc__
@@ -444,11 +442,6 @@ class PandasCogroupedOps:

     applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__

-    @staticmethod
-    def _extract_cols(gd: "GroupedData") -> List[Column]:
-        df = gd._df
-        return [df[col] for col in df.columns]
-

 PandasCogroupedOps.__doc__ = PySparkPandasCogroupedOps.__doc__
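The group.py hunks touch the code path that serves the dict form of
GroupedData.agg, where each {column: function} pair becomes one aggregate
expression. A minimal sketch of that form from the user's side, using the
public API (illustrative only, not part of the patch; assumes an active
SparkSession, and result column names follow Spark's "max(val)" convention):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([("a", 1), ("a", 3), ("b", 2)], ["key", "val"])

    # Dict form: {column: aggregate function name}.
    df.groupBy("key").agg({"val": "max"}).show()

    # Equivalent Column form.
    df.groupBy("key").agg(F.max("val")).show()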
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org