This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d320d784a2 [SPARK-41077][CONNECT][PYTHON][REFACTORING] Rename `ColumnRef` to `Column` in Python client implementation
7d320d784a2 is described below

commit 7d320d784a2d637fd1a8fd0798da3d2a39b4d7cd
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Fri Nov 11 11:03:04 2022 +0900

    [SPARK-41077][CONNECT][PYTHON][REFACTORING] Rename `ColumnRef` to `Column` in Python client implementation
    
    ### What changes were proposed in this pull request?
    
    The Spark Connect Python client uses `ColumnRef` to represent columns in its API (e.g. `df.name`), while current PySpark uses the `Column` class for the same concept. Renaming `ColumnRef` to `Column` aligns Connect with PySpark, so that existing PySpark users can reuse their code with the Spark Connect Python client as much as possible (minimizing code changes).
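    As a minimal illustrative sketch of the alignment (the assertion mirrors the existing unit test; nothing beyond the renamed class and `functions.col` from this patch is assumed):

    ```python
    import pyspark.sql.connect.functions as fun
    from pyspark.sql.connect.column import Column  # formerly ColumnRef

    # fun.col(...) now returns Column, matching the class name that
    # existing PySpark code already programs against.
    c = fun.col("col_name")
    assert isinstance(c, Column)
    ```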
    
    ### Why are the changes needed?
    
    This helps existing PySpark users reuse their code with the Spark Connect Python client as much as possible (minimizing code changes).
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    Existing UT
    
    Closes #38586 from amaliujia/SPARK-41077.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/column.py               | 12 ++++----
 python/pyspark/sql/connect/dataframe.py            | 34 +++++++++++-----------
 python/pyspark/sql/connect/function_builder.py     |  4 +--
 python/pyspark/sql/connect/functions.py            |  6 ++--
 python/pyspark/sql/connect/plan.py                 | 14 ++++-----
 python/pyspark/sql/connect/typing/__init__.pyi     |  4 +--
 .../connect/test_connect_column_expressions.py     |  6 ++--
 7 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 3c9f8c3d736..417bc7097de 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -30,8 +30,8 @@ if TYPE_CHECKING:
 
 def _bin_op(
     name: str, doc: str = "binary function", reverse: bool = False
-) -> Callable[["ColumnRef", Any], "Expression"]:
-    def _(self: "ColumnRef", other: Any) -> "Expression":
+) -> Callable[["Column", Any], "Expression"]:
+    def _(self: "Column", other: Any) -> "Expression":
         if isinstance(other, get_args(PrimitiveType)):
             other = LiteralExpression(other)
         if not reverse:
@@ -163,15 +163,15 @@ class LiteralExpression(Expression):
         return f"Literal({self._value})"
 
 
-class ColumnRef(Expression):
+class Column(Expression):
     """Represents a column reference. There is no guarantee that this column
     actually exists. In the context of this project, we refer by its name and
     treat it as an unresolved attribute. Attributes that have the same fully
     qualified name are identical"""
 
     @classmethod
-    def from_qualified_name(cls, name: str) -> "ColumnRef":
-        return ColumnRef(name)
+    def from_qualified_name(cls, name: str) -> "Column":
+        return Column(name)
 
     def __init__(self, name: str) -> None:
         super().__init__()
@@ -198,7 +198,7 @@ class ColumnRef(Expression):
 
 
 class SortOrder(Expression):
-    def __init__(self, col: ColumnRef, ascending: bool = True, nullsLast: bool = True) -> None:
+    def __init__(self, col: Column, ascending: bool = True, nullsLast: bool = True) -> None:
         super().__init__()
         self.ref = col
         self.ascending = ascending
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e3116ea1250..0c19c67309d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -31,7 +31,7 @@ import pandas
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.column import (
-    ColumnRef,
+    Column,
     Expression,
     LiteralExpression,
 )
@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
     from pyspark.sql.connect.client import RemoteSparkSession
 
-ColumnOrName = Union[ColumnRef, str]
+ColumnOrName = Union[Column, str]
 
 
 class GroupingFrame(object):
@@ -52,9 +52,9 @@ class GroupingFrame(object):
     MeasuresType = Union[Sequence[Tuple["ExpressionOrString", str]], Dict[str, str]]
     OptMeasuresType = Optional[MeasuresType]
 
-    def __init__(self, df: "DataFrame", *grouping_cols: Union[ColumnRef, str]) -> None:
+    def __init__(self, df: "DataFrame", *grouping_cols: Union[Column, str]) -> None:
         self._df = df
-        self._grouping_cols = [x if isinstance(x, ColumnRef) else df[x] for x in grouping_cols]
+        self._grouping_cols = [x if isinstance(x, Column) else df[x] for x in grouping_cols]
 
     def agg(self, exprs: Optional[MeasuresType] = None) -> "DataFrame":
 
@@ -76,18 +76,18 @@ class GroupingFrame(object):
         )
         return res
 
-    def _map_cols_to_dict(self, fun: str, cols: List[Union[ColumnRef, str]]) -> Dict[str, str]:
+    def _map_cols_to_dict(self, fun: str, cols: List[Union[Column, str]]) -> Dict[str, str]:
         return {x if isinstance(x, str) else x.name(): fun for x in cols}
 
-    def min(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+    def min(self, *cols: Union[Column, str]) -> "DataFrame":
         expr = self._map_cols_to_dict("min", list(cols))
         return self.agg(expr)
 
-    def max(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+    def max(self, *cols: Union[Column, str]) -> "DataFrame":
         expr = self._map_cols_to_dict("max", list(cols))
         return self.agg(expr)
 
-    def sum(self, *cols: Union[ColumnRef, str]) -> "DataFrame":
+    def sum(self, *cols: Union[Column, str]) -> "DataFrame":
         expr = self._map_cols_to_dict("sum", list(cols))
         return self.agg(expr)
 
@@ -129,7 +129,7 @@ class DataFrame(object):
     def alias(self, alias: str) -> "DataFrame":
         return DataFrame.withPlan(plan.SubqueryAlias(self._plan, alias), session=self._session)
 
-    def approxQuantile(self, col: ColumnRef, probabilities: Any, relativeError: Any) -> "DataFrame":
+    def approxQuantile(self, col: Column, probabilities: Any, relativeError: Any) -> "DataFrame":
         ...
 
     def colRegex(self, regex: str) -> "DataFrame":
@@ -206,7 +206,7 @@ class DataFrame(object):
             self._session,
         )
 
-    def describe(self, cols: List[ColumnRef]) -> Any:
+    def describe(self, cols: List[Column]) -> Any:
         ...
 
     def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
@@ -250,7 +250,7 @@ class DataFrame(object):
 
     def drop(self, *cols: "ColumnOrString") -> "DataFrame":
         all_cols = self.columns
-        dropped = set([c.name() if isinstance(c, ColumnRef) else self[c].name() for c in cols])
+        dropped = set([c.name() if isinstance(c, Column) else self[c].name() for c in cols])
         dropped_cols = filter(lambda x: x in dropped, all_cols)
         return DataFrame.withPlan(plan.Project(self._plan, *dropped_cols), session=self._session)
 
@@ -320,11 +320,11 @@ class DataFrame(object):
         """
         return self.limit(num).collect()
 
-    # TODO: extend `on` to also be type List[ColumnRef].
+    # TODO: extend `on` to also be type List[Column].
     def join(
         self,
         other: "DataFrame",
-        on: Optional[Union[str, List[str], ColumnRef]] = None,
+        on: Optional[Union[str, List[str], Column]] = None,
         how: Optional[str] = None,
     ) -> "DataFrame":
         if self._plan is None:
@@ -566,16 +566,16 @@ class DataFrame(object):
             p = p._child
         return None
 
-    def __getattr__(self, name: str) -> "ColumnRef":
+    def __getattr__(self, name: str) -> "Column":
         return self[name]
 
-    def __getitem__(self, name: str) -> "ColumnRef":
+    def __getitem__(self, name: str) -> "Column":
         # Check for alias
         alias = self._get_alias()
         if alias is not None:
-            return ColumnRef(alias)
+            return Column(alias)
         else:
-            return ColumnRef(name)
+            return Column(name)
 
     def _print_plan(self) -> str:
         if self._plan:
diff --git a/python/pyspark/sql/connect/function_builder.py b/python/pyspark/sql/connect/function_builder.py
index 9c519312a4f..e116e493954 100644
--- a/python/pyspark/sql/connect/function_builder.py
+++ b/python/pyspark/sql/connect/function_builder.py
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Optional, Any, Iterable, Union
 import pyspark.sql.connect.proto as proto
 import pyspark.sql.types
 from pyspark.sql.connect.column import (
-    ColumnRef,
+    Column,
     Expression,
     ScalarFunctionExpression,
 )
@@ -45,7 +45,7 @@ def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
     -------
     :class:`ScalarFunctionExpression`
     """
-    cols = [x if isinstance(x, Expression) else ColumnRef.from_qualified_name(x) for x in args]
+    cols = [x if isinstance(x, Expression) else Column.from_qualified_name(x) for x in args]
     return ScalarFunctionExpression(name, *cols)
 
 
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 880096da459..00d0a56aedb 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -14,15 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from pyspark.sql.connect.column import ColumnRef, LiteralExpression
+from pyspark.sql.connect.column import Column, LiteralExpression
 
 from typing import Any
 
 # TODO(SPARK-40538) Add support for the missing PySpark functions.
 
 
-def col(x: str) -> ColumnRef:
-    return ColumnRef(x)
+def col(x: str) -> Column:
+    return Column(x)
 
 
 def lit(x: Any) -> LiteralExpression:
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 926119c5457..e5eed195568 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -28,7 +28,7 @@ from typing import (
 
 import pyspark.sql.connect.proto as proto
 from pyspark.sql.connect.column import (
-    ColumnRef,
+    Column,
     Expression,
     SortOrder,
 )
@@ -64,7 +64,7 @@ class LogicalPlan(object):
         if type(col) is str:
             return self.unresolved_attr(col)
         else:
-            return cast(ColumnRef, col).to_plan(session)
+            return cast(Column, col).to_plan(session)
 
     def plan(self, session: "RemoteSparkSession") -> proto.Relation:
         ...
@@ -360,7 +360,7 @@ class Sort(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        columns: List[Union[SortOrder, ColumnRef, str]],
+        columns: List[Union[SortOrder, Column, str]],
         is_global: bool,
     ) -> None:
         super().__init__(child)
@@ -368,7 +368,7 @@ class Sort(LogicalPlan):
         self.is_global = is_global
 
     def col_to_sort_field(
-        self, col: Union[SortOrder, ColumnRef, str], session: "RemoteSparkSession"
+        self, col: Union[SortOrder, Column, str], session: "RemoteSparkSession"
     ) -> proto.Sort.SortField:
         if isinstance(col, SortOrder):
             sf = proto.Sort.SortField()
@@ -387,7 +387,7 @@ class Sort(LogicalPlan):
         else:
             sf = proto.Sort.SortField()
             # Check string
-            if isinstance(col, ColumnRef):
+            if isinstance(col, Column):
                 sf.expression.CopyFrom(col.to_plan(session))
             else:
                 sf.expression.CopyFrom(self.unresolved_attr(col))
@@ -478,7 +478,7 @@ class Aggregate(LogicalPlan):
     def __init__(
         self,
         child: Optional["LogicalPlan"],
-        grouping_cols: List[ColumnRef],
+        grouping_cols: List[Column],
         measures: OptMeasuresType,
     ) -> None:
         super().__init__(child)
@@ -532,7 +532,7 @@ class Join(LogicalPlan):
         self,
         left: Optional["LogicalPlan"],
         right: "LogicalPlan",
-        on: Optional[Union[str, List[str], ColumnRef]],
+        on: Optional[Union[str, List[str], Column]],
         how: Optional[str],
     ) -> None:
         super().__init__(left)
diff --git a/python/pyspark/sql/connect/typing/__init__.pyi b/python/pyspark/sql/connect/typing/__init__.pyi
index d8f8e300324..6c67b561311 100644
--- a/python/pyspark/sql/connect/typing/__init__.pyi
+++ b/python/pyspark/sql/connect/typing/__init__.pyi
@@ -17,12 +17,12 @@
 
 from typing_extensions import Protocol
 from typing import Union
-from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, ColumnRef
+from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
 from pyspark.sql.connect.function_builder import UserDefinedFunction
 
 ExpressionOrString = Union[str, Expression]
 
-ColumnOrString = Union[str, ColumnRef]
+ColumnOrString = Union[str, Column]
 
 class FunctionBuilderCallable(Protocol):
     def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression: ...
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index ca75b14bb67..59e3c97679e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -36,11 +36,11 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         df = self.connect.with_plan(p.Read("table"))
 
         c1 = df.col_name
-        self.assertIsInstance(c1, col.ColumnRef)
+        self.assertIsInstance(c1, col.Column)
         c2 = df["col_name"]
-        self.assertIsInstance(c2, col.ColumnRef)
+        self.assertIsInstance(c2, col.Column)
         c3 = fun.col("col_name")
-        self.assertIsInstance(c3, col.ColumnRef)
+        self.assertIsInstance(c3, col.Column)
 
         # All Protos should be identical
         cp1 = c1.to_plan(None)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
