This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e2a588c91adc [SPARK-48438][PS][CONNECT] Directly use the parent column 
class
e2a588c91adc is described below

commit e2a588c91adc335c34316def0863403330966a30
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed May 29 09:45:34 2024 +0900

    [SPARK-48438][PS][CONNECT] Directly use the parent column class
    
    ### What changes were proposed in this pull request?
    Directly use the parent column class
    
    ### Why are the changes needed?
    The parent column class works with both Spark Classic and Spark Connect, so
    there is no need to use `get_column_class` any more.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46775 from zhengruifeng/ps_column_cleanup.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/pandas/data_type_ops/base.py        | 11 ++---
 python/pyspark/pandas/data_type_ops/boolean_ops.py | 16 +++----
 python/pyspark/pandas/data_type_ops/date_ops.py    | 15 +++----
 python/pyspark/pandas/data_type_ops/num_ops.py     | 50 ++++++++--------------
 python/pyspark/pandas/frame.py                     | 18 ++++----
 python/pyspark/pandas/indexes/multi.py             | 11 ++---
 python/pyspark/pandas/indexing.py                  | 24 ++++-------
 python/pyspark/pandas/internal.py                  | 30 ++++++-------
 python/pyspark/pandas/namespace.py                 | 16 +++----
 python/pyspark/pandas/series.py                    |  7 ++-
 python/pyspark/pandas/spark/accessors.py           |  8 ++--
 python/pyspark/sql/utils.py                        | 20 +--------
 12 files changed, 82 insertions(+), 144 deletions(-)

diff --git a/python/pyspark/pandas/data_type_ops/base.py 
b/python/pyspark/pandas/data_type_ops/base.py
index 2df40252965b..b4a6b1abbcaf 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -24,7 +24,7 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
-from pyspark.sql import functions as F
+from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import (
     ArrayType,
     BinaryType,
@@ -53,9 +53,6 @@ from pyspark.pandas.typedef.typehints import (
     spark_type_to_pandas_dtype,
 )
 
-# For supporting Spark Connect
-from pyspark.sql.utils import get_column_class
-
 if extension_dtypes_available:
     from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
 
@@ -485,16 +482,14 @@ class DataTypeOps(object, metaclass=ABCMeta):
         else:
             from pyspark.pandas.base import column_op
 
-            Column = get_column_class()
-            return column_op(Column.__eq__)(left, right)
+            return column_op(PySparkColumn.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
 
-        Column = get_column_class()
-        return column_op(Column.__ne__)(left, right)
+        return column_op(PySparkColumn.__ne__)(left, right)
 
     def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
         raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
diff --git a/python/pyspark/pandas/data_type_ops/boolean_ops.py 
b/python/pyspark/pandas/data_type_ops/boolean_ops.py
index 7e7ea7eb0738..c91dcc913080 100644
--- a/python/pyspark/pandas/data_type_ops/boolean_ops.py
+++ b/python/pyspark/pandas/data_type_ops/boolean_ops.py
@@ -35,10 +35,8 @@ from pyspark.pandas.data_type_ops.base import (
     _is_boolean_type,
 )
 from pyspark.pandas.typedef.typehints import as_spark_type, extension_dtypes, 
pandas_on_spark_type
-from pyspark.sql import functions as F
-from pyspark.sql.column import Column as PySparkColumn
+from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import BooleanType, StringType
-from pyspark.sql.utils import get_column_class
 from pyspark.errors import PySparkValueError
 
 
@@ -331,23 +329,19 @@ class BooleanOps(DataTypeOps):
 
     def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__lt__)(left, right)
+        return column_op(PySparkColumn.__lt__)(left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__le__)(left, right)
+        return column_op(PySparkColumn.__le__)(left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__ge__)(left, right)
+        return column_op(PySparkColumn.__ge__)(left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__gt__)(left, right)
+        return column_op(PySparkColumn.__gt__)(left, right)
 
     def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
         return operand._with_new_scol(~operand.spark.column, 
field=operand._internal.data_fields[0])
diff --git a/python/pyspark/pandas/data_type_ops/date_ops.py 
b/python/pyspark/pandas/data_type_ops/date_ops.py
index 771b5d38a17a..9a0b82de6ce8 100644
--- a/python/pyspark/pandas/data_type_ops/date_ops.py
+++ b/python/pyspark/pandas/data_type_ops/date_ops.py
@@ -23,9 +23,8 @@ import numpy as np
 import pandas as pd
 from pandas.api.types import CategoricalDtype
 
-from pyspark.sql import functions as F
+from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import BooleanType, DateType, StringType
-from pyspark.sql.utils import get_column_class
 from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
 from pyspark.pandas.base import column_op, IndexOpsMixin
 from pyspark.pandas.data_type_ops.base import (
@@ -84,29 +83,25 @@ class DateOps(DataTypeOps):
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__lt__)(left, right)
+        return column_op(PySparkColumn.__lt__)(left, right)
 
     def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__le__)(left, right)
+        return column_op(PySparkColumn.__le__)(left, right)
 
     def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__ge__)(left, right)
+        return column_op(PySparkColumn.__ge__)(left, right)
 
     def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
 
         _sanitize_list_like(right)
-        Column = get_column_class()
-        return column_op(Column.__gt__)(left, right)
+        return column_op(PySparkColumn.__gt__)(left, right)
 
     def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) 
-> IndexOpsLike:
         dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py 
b/python/pyspark/pandas/data_type_ops/num_ops.py
index 6f393c9652d7..8e8dfee9990e 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -43,8 +43,7 @@ from pyspark.pandas.data_type_ops.base import (
     _is_boolean_type,
 )
 from pyspark.pandas.typedef.typehints import extension_dtypes, 
pandas_on_spark_type
-from pyspark.sql import functions as F
-from pyspark.sql import Column as PySparkColumn
+from pyspark.sql import functions as F, Column as PySparkColumn
 from pyspark.sql.types import (
     BooleanType,
     DataType,
@@ -53,7 +52,7 @@ from pyspark.sql.types import (
 from pyspark.errors import PySparkValueError
 
 # For Supporting Spark Connect
-from pyspark.sql.utils import pyspark_column_op, get_column_class
+from pyspark.sql.utils import pyspark_column_op
 
 
 def _non_fractional_astype(
@@ -82,8 +81,7 @@ class NumericOps(DataTypeOps):
             raise TypeError("Addition can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
-        Column = get_column_class()
-        return column_op(Column.__add__)(left, right)
+        return column_op(PySparkColumn.__add__)(left, right)
 
     def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -91,8 +89,7 @@ class NumericOps(DataTypeOps):
             raise TypeError("Subtraction can not be applied to given types.")
 
         right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
-        Column = get_column_class()
-        return column_op(Column.__sub__)(left, right)
+        return column_op(PySparkColumn.__sub__)(left, right)
 
     def mod(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -110,13 +107,11 @@ class NumericOps(DataTypeOps):
         if not is_valid_operand_for_numeric_arithmetic(right):
             raise TypeError("Exponentiation can not be applied to given 
types.")
 
-        Column = get_column_class()
-
-        def pow_func(left: Column, right: Any) -> Column:  # type: 
ignore[valid-type]
+        def pow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
             return (
-                F.when(left == 1, left)  # type: ignore
+                F.when(left == 1, left)
                 .when(F.lit(right) == 0, 1)
-                .otherwise(Column.__pow__(left, right))
+                .otherwise(PySparkColumn.__pow__(left, right))
             )
 
         right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
@@ -127,34 +122,29 @@ class NumericOps(DataTypeOps):
         if not isinstance(right, numbers.Number):
             raise TypeError("Addition can not be applied to given types.")
         right = transform_boolean_operand_to_numeric(right)
-        Column = get_column_class()
-        return column_op(Column.__radd__)(left, right)
+        return column_op(PySparkColumn.__radd__)(left, right)
 
     def rsub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Subtraction can not be applied to given types.")
         right = transform_boolean_operand_to_numeric(right)
-        Column = get_column_class()
-        return column_op(Column.__rsub__)(left, right)
+        return column_op(PySparkColumn.__rsub__)(left, right)
 
     def rmul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Multiplication can not be applied to given 
types.")
         right = transform_boolean_operand_to_numeric(right)
-        Column = get_column_class()
-        return column_op(Column.__rmul__)(left, right)
+        return column_op(PySparkColumn.__rmul__)(left, right)
 
     def rpow(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
         if not isinstance(right, numbers.Number):
             raise TypeError("Exponentiation can not be applied to given 
types.")
 
-        Column = get_column_class()
-
-        def rpow_func(left: Column, right: Any) -> Column:  # type: 
ignore[valid-type]
-            return F.when(F.lit(right == 1), 
right).otherwise(Column.__rpow__(left, right))
+        def rpow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
+            return F.when(F.lit(right == 1), 
right).otherwise(PySparkColumn.__rpow__(left, right))
 
         right = transform_boolean_operand_to_numeric(right)
         return column_op(rpow_func)(left, right)
@@ -250,8 +240,8 @@ class IntegralOps(NumericOps):
             raise TypeError("Multiplication can not be applied to given 
types.")
 
         right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
-        Column = get_column_class()
-        return column_op(Column.__mul__)(left, right)
+
+        return column_op(PySparkColumn.__mul__)(left, right)
 
     def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -335,8 +325,8 @@ class FractionalOps(NumericOps):
             raise TypeError("Multiplication can not be applied to given 
types.")
 
         right = transform_boolean_operand_to_numeric(right, 
spark_type=left.spark.data_type)
-        Column = get_column_class()
-        return column_op(Column.__mul__)(left, right)
+
+        return column_op(PySparkColumn.__mul__)(left, right)
 
     def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         _sanitize_list_like(right)
@@ -496,13 +486,11 @@ class DecimalOps(FractionalOps):
         if not isinstance(right, numbers.Number):
             raise TypeError("Exponentiation can not be applied to given 
types.")
 
-        Column = get_column_class()
-
-        def rpow_func(left: Column, right: Any) -> Column:  # type: 
ignore[valid-type]
+        def rpow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
             return (
-                F.when(left.isNull(), np.nan)  # type: ignore
+                F.when(left.isNull(), np.nan)
                 .when(F.lit(right == 1), right)
-                .otherwise(Column.__rpow__(left, right))
+                .otherwise(PySparkColumn.__rpow__(left, right))
             )
 
         right = transform_boolean_operand_to_numeric(right)
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3669512d8371..2bf91c2be2f3 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -149,7 +149,7 @@ from pyspark.pandas.typedef.typehints import (
     create_tuple_for_frame_type,
 )
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor
-from pyspark.sql.utils import get_column_class, get_dataframe_class
+from pyspark.sql.utils import get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -5629,10 +5629,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         from pyspark.pandas.indexes import MultiIndex
         from pyspark.pandas.series import IndexOpsMixin
 
-        Column = get_column_class()
         for k, v in kwargs.items():
             is_invalid_assignee = (
-                not (isinstance(v, (IndexOpsMixin, Column)) or callable(v) or 
is_scalar(v))
+                not (isinstance(v, (IndexOpsMixin, PySparkColumn)) or 
callable(v) or is_scalar(v))
             ) or isinstance(v, MultiIndex)
             if is_invalid_assignee:
                 raise TypeError(
@@ -5646,7 +5645,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
                 (v.spark.column, v._internal.data_fields[0])
                 if isinstance(v, IndexOpsMixin) and not isinstance(v, 
MultiIndex)
                 else (v, None)
-                if isinstance(v, Column)
+                if isinstance(v, PySparkColumn)
                 else (F.lit(v), None)
             )
             for k, v in kwargs.items()
@@ -7689,21 +7688,20 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         if na_position not in ("first", "last"):
             raise ValueError("invalid na_position: '{}'".format(na_position))
 
-        Column = get_column_class()
         # Mapper: Get a spark colum
         # n function for (ascending, na_position) combination
         mapper = {
-            (True, "first"): Column.asc_nulls_first,
-            (True, "last"): Column.asc_nulls_last,
-            (False, "first"): Column.desc_nulls_first,
-            (False, "last"): Column.desc_nulls_last,
+            (True, "first"): PySparkColumn.asc_nulls_first,
+            (True, "last"): PySparkColumn.asc_nulls_last,
+            (False, "first"): PySparkColumn.desc_nulls_first,
+            (False, "last"): PySparkColumn.desc_nulls_last,
         }
         by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, 
ascending)]
 
         natural_order_scol = F.col(NATURAL_ORDER_COLUMN_NAME)
 
         if keep == "last":
-            natural_order_scol = Column.desc(natural_order_scol)
+            natural_order_scol = PySparkColumn.desc(natural_order_scol)
         elif keep == "all":
             raise NotImplementedError("`keep`=all is not implemented yet.")
         elif keep != "first":
diff --git a/python/pyspark/pandas/indexes/multi.py 
b/python/pyspark/pandas/indexes/multi.py
index 7d2712cbb531..b5aae890d50a 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -23,7 +23,6 @@ from pandas.api.types import is_hashable, is_list_like  # 
type: ignore[attr-defi
 
 from pyspark.sql import functions as F, Column as PySparkColumn, Window
 from pyspark.sql.types import DataType
-from pyspark.sql.utils import get_column_class
 from pyspark import pandas as ps
 from pyspark.pandas._typing import Label, Name, Scalar
 from pyspark.pandas.exceptions import PandasNotImplementedError
@@ -514,7 +513,6 @@ class MultiIndex(Index):
 
         cond = F.lit(True)
         has_not_null = F.lit(True)
-        Column = get_column_class()
         for scol in self._internal.index_spark_columns[::-1]:
             data_type = self._internal.spark_type_for(scol)
             prev = F.lag(scol, 1).over(window)
@@ -522,7 +520,9 @@ class MultiIndex(Index):
             # Since pandas 1.1.4, null value is not allowed at any levels of 
MultiIndex.
             # Therefore, we should check `has_not_null` over all levels.
             has_not_null = has_not_null & scol.isNotNull()
-            cond = F.when(scol.eqNullSafe(prev), cond).otherwise(compare(scol, 
prev, Column.__gt__))
+            cond = F.when(scol.eqNullSafe(prev), cond).otherwise(
+                compare(scol, prev, PySparkColumn.__gt__)
+            )
 
         cond = has_not_null & (prev.isNull() | cond)
 
@@ -560,7 +560,6 @@ class MultiIndex(Index):
 
         cond = F.lit(True)
         has_not_null = F.lit(True)
-        Column = get_column_class()
         for scol in self._internal.index_spark_columns[::-1]:
             data_type = self._internal.spark_type_for(scol)
             prev = F.lag(scol, 1).over(window)
@@ -568,7 +567,9 @@ class MultiIndex(Index):
             # Since pandas 1.1.4, null value is not allowed at any levels of 
MultiIndex.
             # Therefore, we should check `has_not_null` over all levels.
             has_not_null = has_not_null & scol.isNotNull()
-            cond = F.when(scol.eqNullSafe(prev), cond).otherwise(compare(scol, 
prev, Column.__lt__))
+            cond = F.when(scol.eqNullSafe(prev), cond).otherwise(
+                compare(scol, prev, PySparkColumn.__lt__)
+            )
 
         cond = has_not_null & (prev.isNull() | cond)
 
diff --git a/python/pyspark/pandas/indexing.py 
b/python/pyspark/pandas/indexing.py
index 24b7c53eea99..fada94cf383a 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -50,7 +50,6 @@ from pyspark.pandas.utils import (
     spark_column_equals,
     verify_temp_column_name,
 )
-from pyspark.sql.utils import get_column_class
 
 if TYPE_CHECKING:
     from pyspark.pandas.frame import DataFrame
@@ -259,12 +258,11 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         """
         from pyspark.pandas.series import Series
 
-        Column = get_column_class()
         if rows_sel is None:
             return None, None, None
         elif isinstance(rows_sel, Series):
             return self._select_rows_by_series(rows_sel)
-        elif isinstance(rows_sel, Column):
+        elif isinstance(rows_sel, PySparkColumn):
             return self._select_rows_by_spark_column(rows_sel)
         elif isinstance(rows_sel, slice):
             if rows_sel == slice(None):
@@ -306,7 +304,6 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         """
         from pyspark.pandas.series import Series
 
-        Column = get_column_class()
         if cols_sel is None:
             column_labels = self._internal.column_labels
             data_spark_columns = self._internal.data_spark_columns
@@ -314,7 +311,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
             return column_labels, data_spark_columns, data_fields, False, None
         elif isinstance(cols_sel, Series):
             return self._select_cols_by_series(cols_sel, missing_keys)
-        elif isinstance(cols_sel, Column):
+        elif isinstance(cols_sel, PySparkColumn):
             return self._select_cols_by_spark_column(cols_sel, missing_keys)
         elif isinstance(cols_sel, slice):
             if cols_sel == slice(None):
@@ -579,7 +576,6 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
         from pyspark.pandas.frame import DataFrame
         from pyspark.pandas.series import Series, first_series
 
-        Column = get_column_class()
         if self._is_series:
             if (
                 isinstance(key, Series)
@@ -639,7 +635,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
                     self._internal.spark_frame[cast(iLocIndexer, 
self)._sequence_col] < F.lit(limit)
                 )
 
-            if isinstance(value, (Series, Column)):
+            if isinstance(value, (Series, PySparkColumn)):
                 if remaining_index is not None and remaining_index == 0:
                     raise ValueError(
                         "No axis named {} for object type {}".format(key, 
type(value).__name__)
@@ -724,7 +720,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
                     self._internal.spark_frame[cast(iLocIndexer, 
self)._sequence_col] < F.lit(limit)
                 )
 
-            if isinstance(value, (Series, Column)):
+            if isinstance(value, (Series, PySparkColumn)):
                 if remaining_index is not None and remaining_index == 0:
                     raise ValueError("Incompatible indexer with Series")
                 if len(data_spark_columns) > 1:
@@ -1125,9 +1121,8 @@ class LocIndexer(LocIndexerLike):
                     )
                 )[::-1]:
                     compare = 
MultiIndex._comparator_for_monotonic_increasing(dt)
-                    Column = get_column_class()
                     cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), 
cond).otherwise(
-                        compare(scol, F.lit(value).cast(dt), Column.__gt__)
+                        compare(scol, F.lit(value).cast(dt), 
PySparkColumn.__gt__)
                     )
                 conds.append(cond)
             if stop is not None:
@@ -1140,9 +1135,8 @@ class LocIndexer(LocIndexerLike):
                     )
                 )[::-1]:
                     compare = 
MultiIndex._comparator_for_monotonic_increasing(dt)
-                    Column = get_column_class()
                     cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), 
cond).otherwise(
-                        compare(scol, F.lit(value).cast(dt), Column.__lt__)
+                        compare(scol, F.lit(value).cast(dt), 
PySparkColumn.__lt__)
                     )
                 conds.append(cond)
 
@@ -1300,12 +1294,11 @@ class LocIndexer(LocIndexerLike):
     ]:
         from pyspark.pandas.series import Series
 
-        Column = get_column_class()
         if all(isinstance(key, Series) for key in cols_sel):
             column_labels = [key._column_label for key in cols_sel]
             data_spark_columns = [key.spark.column for key in cols_sel]
             data_fields = [key._internal.data_fields[0] for key in cols_sel]
-        elif all(isinstance(key, Column) for key in cols_sel):
+        elif all(isinstance(key, PySparkColumn) for key in cols_sel):
             column_labels = [
                 (self._internal.spark_frame.select(col).columns[0],) for col 
in cols_sel
             ]
@@ -1804,8 +1797,7 @@ class iLocIndexer(LocIndexerLike):
             )
 
     def __setitem__(self, key: Any, value: Any) -> None:
-        Column = get_column_class()
-        if not isinstance(value, Column) and is_list_like(value):
+        if not isinstance(value, PySparkColumn) and is_list_like(value):
             iloc_item = self[key]
             if not is_list_like(key) or not is_list_like(iloc_item):
                 raise ValueError("setting an array element with a sequence.")
diff --git a/python/pyspark/pandas/internal.py 
b/python/pyspark/pandas/internal.py
index 8ab8d79d5686..04285aa2d879 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -42,7 +42,7 @@ from pyspark.sql.types import (  # noqa: F401
     StringType,
 )
 from pyspark.sql.utils import is_timestamp_ntz_preferred
-from pyspark.sql.utils import is_remote, get_column_class, get_dataframe_class
+from pyspark.sql.utils import is_remote, get_dataframe_class
 from pyspark import pandas as ps
 from pyspark.pandas._typing import Label
 from pyspark.pandas.spark.utils import as_nullable_spark_type, 
force_decimal_precision_scale
@@ -673,12 +673,12 @@ class InternalFrame:
         self._sdf = spark_frame
 
         # index_spark_columns
-        Column = get_column_class()
+
         assert all(
-            isinstance(index_scol, Column) for index_scol in 
index_spark_columns
+            isinstance(index_scol, PySparkColumn) for index_scol in 
index_spark_columns
         ), index_spark_columns
 
-        self._index_spark_columns: List[Column] = index_spark_columns  # type: 
ignore[valid-type]
+        self._index_spark_columns: List[PySparkColumn] = index_spark_columns
 
         # data_spark_columns
         if data_spark_columns is None:
@@ -692,9 +692,9 @@ class InternalFrame:
                 and col not in HIDDEN_COLUMNS
             ]
         else:
-            assert all(isinstance(scol, Column) for scol in data_spark_columns)
+            assert all(isinstance(scol, PySparkColumn) for scol in 
data_spark_columns)
 
-        self._data_spark_columns: List[Column] = data_spark_columns  # type: 
ignore[valid-type]
+        self._data_spark_columns: List[PySparkColumn] = data_spark_columns
 
         # fields
         if index_fields is None:
@@ -974,27 +974,27 @@ class InternalFrame:
 
     def spark_column_name_for(self, label_or_scol: Union[Label, 
PySparkColumn]) -> str:
         """Return the actual Spark column name for the given column label."""
-        Column = get_column_class()
-        if isinstance(label_or_scol, Column):
+
+        if isinstance(label_or_scol, PySparkColumn):
             return self.spark_frame.select(label_or_scol).columns[0]
         else:
-            return self.field_for(label_or_scol).name  # type: ignore[arg-type]
+            return self.field_for(label_or_scol).name
 
     def spark_type_for(self, label_or_scol: Union[Label, PySparkColumn]) -> 
DataType:
         """Return DataType for the given column label."""
-        Column = get_column_class()
-        if isinstance(label_or_scol, Column):
+
+        if isinstance(label_or_scol, PySparkColumn):
             return self.spark_frame.select(label_or_scol).schema[0].dataType
         else:
-            return self.field_for(label_or_scol).spark_type  # type: 
ignore[arg-type]
+            return self.field_for(label_or_scol).spark_type
 
     def spark_column_nullable_for(self, label_or_scol: Union[Label, 
PySparkColumn]) -> bool:
         """Return nullability for the given column label."""
-        Column = get_column_class()
-        if isinstance(label_or_scol, Column):
+
+        if isinstance(label_or_scol, PySparkColumn):
             return self.spark_frame.select(label_or_scol).schema[0].nullable
         else:
-            return self.field_for(label_or_scol).nullable  # type: 
ignore[arg-type]
+            return self.field_for(label_or_scol).nullable
 
     def field_for(self, label: Label) -> InternalField:
         """Return InternalField for the given column label."""
diff --git a/python/pyspark/pandas/namespace.py 
b/python/pyspark/pandas/namespace.py
index 42a0ce49faa5..4cea4b4fff22 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -94,9 +94,6 @@ from pyspark.pandas.spark.utils import 
as_nullable_spark_type, force_decimal_pre
 from pyspark.pandas.indexes import Index, DatetimeIndex, TimedeltaIndex
 from pyspark.pandas.indexes.multi import MultiIndex
 
-# For Supporting Spark Connect
-from pyspark.sql.utils import get_column_class
-
 __all__ = [
     "from_pandas",
     "range",
@@ -3398,8 +3395,7 @@ def merge_asof(
     else:
         on = None
 
-    Column = get_column_class()
-    if tolerance is not None and not isinstance(tolerance, Column):
+    if tolerance is not None and not isinstance(tolerance, PySparkColumn):
         tolerance = F.lit(tolerance)
 
     as_of_joined_table = left_table._joinAsOf(
@@ -3424,10 +3420,10 @@ def merge_asof(
     data_columns = []
     column_labels = []
 
-    def left_scol_for(label: Label) -> Column:  # type: ignore[valid-type]
+    def left_scol_for(label: Label) -> PySparkColumn:
         return scol_for(as_of_joined_table, 
left_internal.spark_column_name_for(label))
 
-    def right_scol_for(label: Label) -> Column:  # type: ignore[valid-type]
+    def right_scol_for(label: Label) -> PySparkColumn:
         return scol_for(as_of_joined_table, 
right_internal.spark_column_name_for(label))
 
     for label in left_internal.column_labels:
@@ -3441,7 +3437,7 @@ def merge_asof(
                 pass
             else:
                 col = col + left_suffix
-                scol = scol.alias(col)  # type: ignore[attr-defined]
+                scol = scol.alias(col)
                 label = tuple([str(label[0]) + left_suffix] + list(label[1:]))
         exprs.append(scol)
         data_columns.append(col)
@@ -3449,7 +3445,7 @@ def merge_asof(
     for label in right_internal.column_labels:
         # recover `right_prefix` here.
         col = right_internal.spark_column_name_for(label)[len(right_prefix) :]
-        scol = right_scol_for(label).alias(col)  # type: ignore[attr-defined]
+        scol = right_scol_for(label).alias(col)
         if label in duplicate_columns:
             spark_column_name = left_internal.spark_column_name_for(label)
             if spark_column_name in left_as_of_names + left_join_on_names and (
@@ -3458,7 +3454,7 @@ def merge_asof(
                 continue
             else:
                 col = col + right_suffix
-                scol = scol.alias(col)  # type: ignore[attr-defined]
+                scol = scol.alias(col)
                 label = tuple([str(label[0]) + right_suffix] + list(label[1:]))
         exprs.append(scol)
         data_columns.append(col)
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 8edc2c531b51..ac11f327258a 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -70,7 +70,7 @@ from pyspark.sql.types import (
     NullType,
 )
 from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class, get_window_class
+from pyspark.sql.utils import get_window_class
 from pyspark import pandas as ps  # For running doctests and reference 
resolution in PyCharm.
 from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
 from pyspark.pandas.accessors import PandasOnSparkSeriesMethods
@@ -4171,11 +4171,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
         if self._internal.index_level > 1:
             raise NotImplementedError("rank do not support MultiIndex now")
 
-        Column = get_column_class()
         if ascending:
-            asc_func = Column.asc
+            asc_func = PySparkColumn.asc
         else:
-            asc_func = Column.desc
+            asc_func = PySparkColumn.desc
 
         if method == "first":
             window = (
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py
index cbb0469a474f..b73d24b12d9d 100644
--- a/python/pyspark/pandas/spark/accessors.py
+++ b/python/pyspark/pandas/spark/accessors.py
@@ -27,7 +27,7 @@ from pyspark.sql import Column as PySparkColumn, DataFrame as PySparkDataFrame
 from pyspark.sql.types import DataType, StructType
 from pyspark.pandas._typing import IndexOpsLike
 from pyspark.pandas.internal import InternalField
-from pyspark.sql.utils import get_column_class, get_dataframe_class
+from pyspark.sql.utils import get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -116,8 +116,7 @@ class SparkIndexOpsMethods(Generic[IndexOpsLike], metaclass=ABCMeta):
         if isinstance(self._data, MultiIndex):
             raise NotImplementedError("MultiIndex does not support spark.transform yet.")
         output = func(self._data.spark.column)
-        Column = get_column_class()
-        if not isinstance(output, Column):
+        if not isinstance(output, PySparkColumn):
             raise ValueError(
                 "The output of the function [%s] should be of a "
                 "pyspark.sql.Column; however, got [%s]." % (func, type(output))
@@ -192,8 +191,7 @@ class SparkSeriesMethods(SparkIndexOpsMethods["ps.Series"]):
         from pyspark.pandas.internal import HIDDEN_COLUMNS
 
         output = func(self._data.spark.column)
-        Column = get_column_class()
-        if not isinstance(output, Column):
+        if not isinstance(output, PySparkColumn):
             raise ValueError(
                 "The output of the function [%s] should be of a "
                 "pyspark.sql.Column; however, got [%s]." % (func, type(output))
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 33e01ba378c4..df0451fa1bd2 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -57,7 +57,6 @@ if TYPE_CHECKING:
     from pyspark import SparkContext
     from pyspark.sql.session import SparkSession
     from pyspark.sql.dataframe import DataFrame
-    from pyspark.sql.column import Column
     from pyspark.sql.window import Window
     from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
 
@@ -412,15 +411,9 @@ def pyspark_column_op(
     Wrapper function for column_op to get proper Column class.
     """
     from pyspark.pandas.base import column_op
-    from pyspark.sql.column import Column as PySparkColumn
+    from pyspark.sql.column import Column
     from pyspark.pandas.data_type_ops.base import _is_extension_dtypes
 
-    if is_remote():
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
-        Column = ConnectColumn
-    else:
-        Column = PySparkColumn  # type: ignore[assignment]
     result = column_op(getattr(Column, func_name))(left, right)
     # It works as expected on extension dtype, so we don't need to call `fillna` for this case.
     if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
@@ -429,17 +422,6 @@ def pyspark_column_op(
     return result.fillna(fillna) if fillna is not None else result
 
 
-def get_column_class() -> Type["Column"]:
-    from pyspark.sql.column import Column as PySparkColumn
-
-    if is_remote():
-        from pyspark.sql.connect.column import Column as ConnectColumn
-
-        return ConnectColumn
-    else:
-        return PySparkColumn
-
-
 def get_dataframe_class() -> Type["DataFrame"]:
     from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to