This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e2a588c91adc [SPARK-48438][PS][CONNECT] Directly use the parent column class
e2a588c91adc is described below
commit e2a588c91adc335c34316def0863403330966a30
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed May 29 09:45:34 2024 +0900
[SPARK-48438][PS][CONNECT] Directly use the parent column class
### What changes were proposed in this pull request?
Directly use the parent column class
### Why are the changes needed?
The parent column class works with both Spark Classic and Spark Connect, so there is no need to use `get_column_class` anymore.
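For illustration, a minimal before/after sketch of the pattern this PR removes (`column_op` and the `eq` signature are simplified from `data_type_ops/base.py` in the diff below):

```python
from pyspark.pandas.base import column_op

# Before: resolve the concrete Column class at runtime.
from pyspark.sql.utils import get_column_class  # removed by this PR

def eq_before(left, right):
    Column = get_column_class()  # ConnectColumn when remote, else pyspark.sql.Column
    return column_op(Column.__eq__)(left, right)

# After: the parent class covers both Spark Classic and Spark Connect.
from pyspark.sql import Column as PySparkColumn

def eq_after(left, right):
    return column_op(PySparkColumn.__eq__)(left, right)
```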
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #46775 from zhengruifeng/ps_column_cleanup.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/pandas/data_type_ops/base.py | 11 ++---
python/pyspark/pandas/data_type_ops/boolean_ops.py | 16 +++----
python/pyspark/pandas/data_type_ops/date_ops.py | 15 +++----
python/pyspark/pandas/data_type_ops/num_ops.py | 50 ++++++++--------------
python/pyspark/pandas/frame.py | 18 ++++----
python/pyspark/pandas/indexes/multi.py | 11 ++---
python/pyspark/pandas/indexing.py | 24 ++++-------
python/pyspark/pandas/internal.py | 30 ++++++-------
python/pyspark/pandas/namespace.py | 16 +++----
python/pyspark/pandas/series.py | 7 ++-
python/pyspark/pandas/spark/accessors.py | 8 ++--
python/pyspark/sql/utils.py | 20 +--------
12 files changed, 82 insertions(+), 144 deletions(-)
diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py
index 2df40252965b..b4a6b1abbcaf 100644
--- a/python/pyspark/pandas/data_type_ops/base.py
+++ b/python/pyspark/pandas/data_type_ops/base.py
@@ -24,7 +24,7 @@ import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
-from pyspark.sql import functions as F
+from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import (
ArrayType,
BinaryType,
@@ -53,9 +53,6 @@ from pyspark.pandas.typedef.typehints import (
spark_type_to_pandas_dtype,
)
-# For supporting Spark Connect
-from pyspark.sql.utils import get_column_class
-
if extension_dtypes_available:
from pandas import Int8Dtype, Int16Dtype, Int32Dtype, Int64Dtype
@@ -485,16 +482,14 @@ class DataTypeOps(object, metaclass=ABCMeta):
else:
from pyspark.pandas.base import column_op
- Column = get_column_class()
- return column_op(Column.__eq__)(left, right)
+ return column_op(PySparkColumn.__eq__)(left, right)
def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__ne__)(left, right)
+ return column_op(PySparkColumn.__ne__)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
raise TypeError("Unary ~ can not be applied to %s." % self.pretty_name)
diff --git a/python/pyspark/pandas/data_type_ops/boolean_ops.py b/python/pyspark/pandas/data_type_ops/boolean_ops.py
index 7e7ea7eb0738..c91dcc913080 100644
--- a/python/pyspark/pandas/data_type_ops/boolean_ops.py
+++ b/python/pyspark/pandas/data_type_ops/boolean_ops.py
@@ -35,10 +35,8 @@ from pyspark.pandas.data_type_ops.base import (
_is_boolean_type,
)
from pyspark.pandas.typedef.typehints import as_spark_type, extension_dtypes, pandas_on_spark_type
-from pyspark.sql import functions as F
-from pyspark.sql.column import Column as PySparkColumn
+from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import BooleanType, StringType
-from pyspark.sql.utils import get_column_class
from pyspark.errors import PySparkValueError
@@ -331,23 +329,19 @@ class BooleanOps(DataTypeOps):
def lt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__lt__)(left, right)
+ return column_op(PySparkColumn.__lt__)(left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__le__)(left, right)
+ return column_op(PySparkColumn.__le__)(left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__ge__)(left, right)
+ return column_op(PySparkColumn.__ge__)(left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__gt__)(left, right)
+ return column_op(PySparkColumn.__gt__)(left, right)
def invert(self, operand: IndexOpsLike) -> IndexOpsLike:
return operand._with_new_scol(~operand.spark.column, field=operand._internal.data_fields[0])
diff --git a/python/pyspark/pandas/data_type_ops/date_ops.py b/python/pyspark/pandas/data_type_ops/date_ops.py
index 771b5d38a17a..9a0b82de6ce8 100644
--- a/python/pyspark/pandas/data_type_ops/date_ops.py
+++ b/python/pyspark/pandas/data_type_ops/date_ops.py
@@ -23,9 +23,8 @@ import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype
-from pyspark.sql import functions as F
+from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import BooleanType, DateType, StringType
-from pyspark.sql.utils import get_column_class
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas.base import column_op, IndexOpsMixin
from pyspark.pandas.data_type_ops.base import (
@@ -84,29 +83,25 @@ class DateOps(DataTypeOps):
from pyspark.pandas.base import column_op
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__lt__)(left, right)
+ return column_op(PySparkColumn.__lt__)(left, right)
def le(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__le__)(left, right)
+ return column_op(PySparkColumn.__le__)(left, right)
def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__ge__)(left, right)
+ return column_op(PySparkColumn.__ge__)(left, right)
def gt(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
from pyspark.pandas.base import column_op
_sanitize_list_like(right)
- Column = get_column_class()
- return column_op(Column.__gt__)(left, right)
+ return column_op(PySparkColumn.__gt__)(left, right)
def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py
index 6f393c9652d7..8e8dfee9990e 100644
--- a/python/pyspark/pandas/data_type_ops/num_ops.py
+++ b/python/pyspark/pandas/data_type_ops/num_ops.py
@@ -43,8 +43,7 @@ from pyspark.pandas.data_type_ops.base import (
_is_boolean_type,
)
from pyspark.pandas.typedef.typehints import extension_dtypes, pandas_on_spark_type
-from pyspark.sql import functions as F
-from pyspark.sql import Column as PySparkColumn
+from pyspark.sql import functions as F, Column as PySparkColumn
from pyspark.sql.types import (
BooleanType,
DataType,
@@ -53,7 +52,7 @@ from pyspark.sql.types import (
from pyspark.errors import PySparkValueError
# For Supporting Spark Connect
-from pyspark.sql.utils import pyspark_column_op, get_column_class
+from pyspark.sql.utils import pyspark_column_op
def _non_fractional_astype(
@@ -82,8 +81,7 @@ class NumericOps(DataTypeOps):
raise TypeError("Addition can not be applied to given types.")
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
- Column = get_column_class()
- return column_op(Column.__add__)(left, right)
+ return column_op(PySparkColumn.__add__)(left, right)
def sub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
@@ -91,8 +89,7 @@ class NumericOps(DataTypeOps):
raise TypeError("Subtraction can not be applied to given types.")
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
- Column = get_column_class()
- return column_op(Column.__sub__)(left, right)
+ return column_op(PySparkColumn.__sub__)(left, right)
def mod(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
@@ -110,13 +107,11 @@ class NumericOps(DataTypeOps):
if not is_valid_operand_for_numeric_arithmetic(right):
raise TypeError("Exponentiation can not be applied to given
types.")
- Column = get_column_class()
-
- def pow_func(left: Column, right: Any) -> Column: # type: ignore[valid-type]
+ def pow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
return (
- F.when(left == 1, left) # type: ignore
+ F.when(left == 1, left)
.when(F.lit(right) == 0, 1)
- .otherwise(Column.__pow__(left, right))
+ .otherwise(PySparkColumn.__pow__(left, right))
)
right = transform_boolean_operand_to_numeric(right, spark_type=left.spark.data_type)
@@ -127,34 +122,29 @@ class NumericOps(DataTypeOps):
if not isinstance(right, numbers.Number):
raise TypeError("Addition can not be applied to given types.")
right = transform_boolean_operand_to_numeric(right)
- Column = get_column_class()
- return column_op(Column.__radd__)(left, right)
+ return column_op(PySparkColumn.__radd__)(left, right)
def rsub(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
if not isinstance(right, numbers.Number):
raise TypeError("Subtraction can not be applied to given types.")
right = transform_boolean_operand_to_numeric(right)
- Column = get_column_class()
- return column_op(Column.__rsub__)(left, right)
+ return column_op(PySparkColumn.__rsub__)(left, right)
def rmul(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
if not isinstance(right, numbers.Number):
raise TypeError("Multiplication can not be applied to given
types.")
right = transform_boolean_operand_to_numeric(right)
- Column = get_column_class()
- return column_op(Column.__rmul__)(left, right)
+ return column_op(PySparkColumn.__rmul__)(left, right)
def rpow(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
if not isinstance(right, numbers.Number):
raise TypeError("Exponentiation can not be applied to given
types.")
- Column = get_column_class()
-
- def rpow_func(left: Column, right: Any) -> Column: # type: ignore[valid-type]
- return F.when(F.lit(right == 1), right).otherwise(Column.__rpow__(left, right))
+ def rpow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
+ return F.when(F.lit(right == 1), right).otherwise(PySparkColumn.__rpow__(left, right))
right = transform_boolean_operand_to_numeric(right)
return column_op(rpow_func)(left, right)
@@ -250,8 +240,8 @@ class IntegralOps(NumericOps):
raise TypeError("Multiplication can not be applied to given
types.")
right = transform_boolean_operand_to_numeric(right,
spark_type=left.spark.data_type)
- Column = get_column_class()
- return column_op(Column.__mul__)(left, right)
+
+ return column_op(PySparkColumn.__mul__)(left, right)
def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
@@ -335,8 +325,8 @@ class FractionalOps(NumericOps):
raise TypeError("Multiplication can not be applied to given
types.")
right = transform_boolean_operand_to_numeric(right,
spark_type=left.spark.data_type)
- Column = get_column_class()
- return column_op(Column.__mul__)(left, right)
+
+ return column_op(PySparkColumn.__mul__)(left, right)
def truediv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
_sanitize_list_like(right)
@@ -496,13 +486,11 @@ class DecimalOps(FractionalOps):
if not isinstance(right, numbers.Number):
raise TypeError("Exponentiation can not be applied to given
types.")
- Column = get_column_class()
-
- def rpow_func(left: Column, right: Any) -> Column: # type: ignore[valid-type]
+ def rpow_func(left: PySparkColumn, right: Any) -> PySparkColumn:
return (
- F.when(left.isNull(), np.nan) # type: ignore
+ F.when(left.isNull(), np.nan)
.when(F.lit(right == 1), right)
- .otherwise(Column.__rpow__(left, right))
+ .otherwise(PySparkColumn.__rpow__(left, right))
)
right = transform_boolean_operand_to_numeric(right)
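For context on the hunks above: the `when` chains in `pow_func`/`rpow_func` encode the conventions `1 ** x == 1` and `x ** 0 == 1` even for NaN operands, which a bare `__pow__` would not guarantee. A minimal sketch (assuming a running Spark session; `pow_func` mirrors the shape of the hunk above):

```python
from pyspark.sql import SparkSession, functions as F, Column as PySparkColumn

spark = SparkSession.builder.getOrCreate()

def pow_func(left: PySparkColumn, right) -> PySparkColumn:
    return (
        F.when(left == 1, left)      # 1 ** anything -> 1, even 1 ** NaN
        .when(F.lit(right) == 0, 1)  # anything ** 0 -> 1
        .otherwise(PySparkColumn.__pow__(left, right))
    )

df = spark.createDataFrame([(1.0,), (2.0,)], ["x"])
df.select(pow_func(F.col("x"), float("nan")).alias("p")).show()
# x = 1.0 -> 1.0 (short-circuited); x = 2.0 -> NaN (falls through to POWER)
```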
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 3669512d8371..2bf91c2be2f3 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -149,7 +149,7 @@ from pyspark.pandas.typedef.typehints import (
create_tuple_for_frame_type,
)
from pyspark.pandas.plot import PandasOnSparkPlotAccessor
-from pyspark.sql.utils import get_column_class, get_dataframe_class
+from pyspark.sql.utils import get_dataframe_class
if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
@@ -5629,10 +5629,9 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
from pyspark.pandas.indexes import MultiIndex
from pyspark.pandas.series import IndexOpsMixin
- Column = get_column_class()
for k, v in kwargs.items():
is_invalid_assignee = (
- not (isinstance(v, (IndexOpsMixin, Column)) or callable(v) or is_scalar(v))
+ not (isinstance(v, (IndexOpsMixin, PySparkColumn)) or callable(v) or is_scalar(v))
) or isinstance(v, MultiIndex)
if is_invalid_assignee:
raise TypeError(
@@ -5646,7 +5645,7 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
(v.spark.column, v._internal.data_fields[0])
if isinstance(v, IndexOpsMixin) and not isinstance(v, MultiIndex)
else (v, None)
- if isinstance(v, Column)
+ if isinstance(v, PySparkColumn)
else (F.lit(v), None)
)
for k, v in kwargs.items()
@@ -7689,21 +7688,20 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
if na_position not in ("first", "last"):
raise ValueError("invalid na_position: '{}'".format(na_position))
- Column = get_column_class()
# Mapper: Get a spark column function for (ascending, na_position) combination
mapper = {
- (True, "first"): Column.asc_nulls_first,
- (True, "last"): Column.asc_nulls_last,
- (False, "first"): Column.desc_nulls_first,
- (False, "last"): Column.desc_nulls_last,
+ (True, "first"): PySparkColumn.asc_nulls_first,
+ (True, "last"): PySparkColumn.asc_nulls_last,
+ (False, "first"): PySparkColumn.desc_nulls_first,
+ (False, "last"): PySparkColumn.desc_nulls_last,
}
by = [mapper[(asc, na_position)](scol) for scol, asc in zip(by, ascending)]
natural_order_scol = F.col(NATURAL_ORDER_COLUMN_NAME)
if keep == "last":
- natural_order_scol = Column.desc(natural_order_scol)
+ natural_order_scol = PySparkColumn.desc(natural_order_scol)
elif keep == "all":
raise NotImplementedError("`keep`=all is not implemented yet.")
elif keep != "first":
diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py
index 7d2712cbb531..b5aae890d50a 100644
--- a/python/pyspark/pandas/indexes/multi.py
+++ b/python/pyspark/pandas/indexes/multi.py
@@ -23,7 +23,6 @@ from pandas.api.types import is_hashable, is_list_like # type: ignore[attr-defi
from pyspark.sql import functions as F, Column as PySparkColumn, Window
from pyspark.sql.types import DataType
-from pyspark.sql.utils import get_column_class
from pyspark import pandas as ps
from pyspark.pandas._typing import Label, Name, Scalar
from pyspark.pandas.exceptions import PandasNotImplementedError
@@ -514,7 +513,6 @@ class MultiIndex(Index):
cond = F.lit(True)
has_not_null = F.lit(True)
- Column = get_column_class()
for scol in self._internal.index_spark_columns[::-1]:
data_type = self._internal.spark_type_for(scol)
prev = F.lag(scol, 1).over(window)
@@ -522,7 +520,9 @@ class MultiIndex(Index):
# Since pandas 1.1.4, null value is not allowed at any levels of MultiIndex.
# Therefore, we should check `has_not_null` over all levels.
has_not_null = has_not_null & scol.isNotNull()
- cond = F.when(scol.eqNullSafe(prev), cond).otherwise(compare(scol, prev, Column.__gt__))
+ cond = F.when(scol.eqNullSafe(prev), cond).otherwise(
+ compare(scol, prev, PySparkColumn.__gt__)
+ )
cond = has_not_null & (prev.isNull() | cond)
@@ -560,7 +560,6 @@ class MultiIndex(Index):
cond = F.lit(True)
has_not_null = F.lit(True)
- Column = get_column_class()
for scol in self._internal.index_spark_columns[::-1]:
data_type = self._internal.spark_type_for(scol)
prev = F.lag(scol, 1).over(window)
@@ -568,7 +567,9 @@ class MultiIndex(Index):
# Since pandas 1.1.4, null value is not allowed at any levels of MultiIndex.
# Therefore, we should check `has_not_null` over all levels.
has_not_null = has_not_null & scol.isNotNull()
- cond = F.when(scol.eqNullSafe(prev), cond).otherwise(compare(scol, prev, Column.__lt__))
+ cond = F.when(scol.eqNullSafe(prev), cond).otherwise(
+ compare(scol, prev, PySparkColumn.__lt__)
+ )
cond = has_not_null & (prev.isNull() | cond)
diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py
index 24b7c53eea99..fada94cf383a 100644
--- a/python/pyspark/pandas/indexing.py
+++ b/python/pyspark/pandas/indexing.py
@@ -50,7 +50,6 @@ from pyspark.pandas.utils import (
spark_column_equals,
verify_temp_column_name,
)
-from pyspark.sql.utils import get_column_class
if TYPE_CHECKING:
from pyspark.pandas.frame import DataFrame
@@ -259,12 +258,11 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
"""
from pyspark.pandas.series import Series
- Column = get_column_class()
if rows_sel is None:
return None, None, None
elif isinstance(rows_sel, Series):
return self._select_rows_by_series(rows_sel)
- elif isinstance(rows_sel, Column):
+ elif isinstance(rows_sel, PySparkColumn):
return self._select_rows_by_spark_column(rows_sel)
elif isinstance(rows_sel, slice):
if rows_sel == slice(None):
@@ -306,7 +304,6 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
"""
from pyspark.pandas.series import Series
- Column = get_column_class()
if cols_sel is None:
column_labels = self._internal.column_labels
data_spark_columns = self._internal.data_spark_columns
@@ -314,7 +311,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
return column_labels, data_spark_columns, data_fields, False, None
elif isinstance(cols_sel, Series):
return self._select_cols_by_series(cols_sel, missing_keys)
- elif isinstance(cols_sel, Column):
+ elif isinstance(cols_sel, PySparkColumn):
return self._select_cols_by_spark_column(cols_sel, missing_keys)
elif isinstance(cols_sel, slice):
if cols_sel == slice(None):
@@ -579,7 +576,6 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series, first_series
- Column = get_column_class()
if self._is_series:
if (
isinstance(key, Series)
@@ -639,7 +635,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
self._internal.spark_frame[cast(iLocIndexer, self)._sequence_col] < F.lit(limit)
)
- if isinstance(value, (Series, Column)):
+ if isinstance(value, (Series, PySparkColumn)):
if remaining_index is not None and remaining_index == 0:
raise ValueError(
"No axis named {} for object type {}".format(key,
type(value).__name__)
@@ -724,7 +720,7 @@ class LocIndexerLike(IndexerLike, metaclass=ABCMeta):
self._internal.spark_frame[cast(iLocIndexer, self)._sequence_col] < F.lit(limit)
)
- if isinstance(value, (Series, Column)):
+ if isinstance(value, (Series, PySparkColumn)):
if remaining_index is not None and remaining_index == 0:
raise ValueError("Incompatible indexer with Series")
if len(data_spark_columns) > 1:
@@ -1125,9 +1121,8 @@ class LocIndexer(LocIndexerLike):
)
)[::-1]:
compare = MultiIndex._comparator_for_monotonic_increasing(dt)
- Column = get_column_class()
cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), cond).otherwise(
- compare(scol, F.lit(value).cast(dt), Column.__gt__)
+ compare(scol, F.lit(value).cast(dt), PySparkColumn.__gt__)
)
conds.append(cond)
if stop is not None:
@@ -1140,9 +1135,8 @@ class LocIndexer(LocIndexerLike):
)
)[::-1]:
compare = MultiIndex._comparator_for_monotonic_increasing(dt)
- Column = get_column_class()
cond = F.when(scol.eqNullSafe(F.lit(value).cast(dt)), cond).otherwise(
- compare(scol, F.lit(value).cast(dt), Column.__lt__)
+ compare(scol, F.lit(value).cast(dt), PySparkColumn.__lt__)
)
conds.append(cond)
@@ -1300,12 +1294,11 @@ class LocIndexer(LocIndexerLike):
]:
from pyspark.pandas.series import Series
- Column = get_column_class()
if all(isinstance(key, Series) for key in cols_sel):
column_labels = [key._column_label for key in cols_sel]
data_spark_columns = [key.spark.column for key in cols_sel]
data_fields = [key._internal.data_fields[0] for key in cols_sel]
- elif all(isinstance(key, Column) for key in cols_sel):
+ elif all(isinstance(key, PySparkColumn) for key in cols_sel):
column_labels = [
(self._internal.spark_frame.select(col).columns[0],) for col in cols_sel
]
@@ -1804,8 +1797,7 @@ class iLocIndexer(LocIndexerLike):
)
def __setitem__(self, key: Any, value: Any) -> None:
- Column = get_column_class()
- if not isinstance(value, Column) and is_list_like(value):
+ if not isinstance(value, PySparkColumn) and is_list_like(value):
iloc_item = self[key]
if not is_list_like(key) or not is_list_like(iloc_item):
raise ValueError("setting an array element with a sequence.")
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 8ab8d79d5686..04285aa2d879 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -42,7 +42,7 @@ from pyspark.sql.types import ( # noqa: F401
StringType,
)
from pyspark.sql.utils import is_timestamp_ntz_preferred
-from pyspark.sql.utils import is_remote, get_column_class, get_dataframe_class
+from pyspark.sql.utils import is_remote, get_dataframe_class
from pyspark import pandas as ps
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
@@ -673,12 +673,12 @@ class InternalFrame:
self._sdf = spark_frame
# index_spark_columns
- Column = get_column_class()
+
assert all(
- isinstance(index_scol, Column) for index_scol in index_spark_columns
+ isinstance(index_scol, PySparkColumn) for index_scol in index_spark_columns
), index_spark_columns
- self._index_spark_columns: List[Column] = index_spark_columns # type: ignore[valid-type]
+ self._index_spark_columns: List[PySparkColumn] = index_spark_columns
# data_spark_columns
if data_spark_columns is None:
@@ -692,9 +692,9 @@ class InternalFrame:
and col not in HIDDEN_COLUMNS
]
else:
- assert all(isinstance(scol, Column) for scol in data_spark_columns)
+ assert all(isinstance(scol, PySparkColumn) for scol in data_spark_columns)
- self._data_spark_columns: List[Column] = data_spark_columns # type: ignore[valid-type]
+ self._data_spark_columns: List[PySparkColumn] = data_spark_columns
# fields
if index_fields is None:
@@ -974,27 +974,27 @@ class InternalFrame:
def spark_column_name_for(self, label_or_scol: Union[Label, PySparkColumn]) -> str:
"""Return the actual Spark column name for the given column label."""
- Column = get_column_class()
- if isinstance(label_or_scol, Column):
+
+ if isinstance(label_or_scol, PySparkColumn):
return self.spark_frame.select(label_or_scol).columns[0]
else:
- return self.field_for(label_or_scol).name # type: ignore[arg-type]
+ return self.field_for(label_or_scol).name
def spark_type_for(self, label_or_scol: Union[Label, PySparkColumn]) -> DataType:
"""Return DataType for the given column label."""
- Column = get_column_class()
- if isinstance(label_or_scol, Column):
+
+ if isinstance(label_or_scol, PySparkColumn):
return self.spark_frame.select(label_or_scol).schema[0].dataType
else:
- return self.field_for(label_or_scol).spark_type # type: ignore[arg-type]
+ return self.field_for(label_or_scol).spark_type
def spark_column_nullable_for(self, label_or_scol: Union[Label, PySparkColumn]) -> bool:
"""Return nullability for the given column label."""
- Column = get_column_class()
- if isinstance(label_or_scol, Column):
+
+ if isinstance(label_or_scol, PySparkColumn):
return self.spark_frame.select(label_or_scol).schema[0].nullable
else:
- return self.field_for(label_or_scol).nullable # type: ignore[arg-type]
+ return self.field_for(label_or_scol).nullable
def field_for(self, label: Label) -> InternalField:
"""Return InternalField for the given column label."""
diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py
index 42a0ce49faa5..4cea4b4fff22 100644
--- a/python/pyspark/pandas/namespace.py
+++ b/python/pyspark/pandas/namespace.py
@@ -94,9 +94,6 @@ from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_pre
from pyspark.pandas.indexes import Index, DatetimeIndex, TimedeltaIndex
from pyspark.pandas.indexes.multi import MultiIndex
-# For Supporting Spark Connect
-from pyspark.sql.utils import get_column_class
-
__all__ = [
"from_pandas",
"range",
@@ -3398,8 +3395,7 @@ def merge_asof(
else:
on = None
- Column = get_column_class()
- if tolerance is not None and not isinstance(tolerance, Column):
+ if tolerance is not None and not isinstance(tolerance, PySparkColumn):
tolerance = F.lit(tolerance)
as_of_joined_table = left_table._joinAsOf(
@@ -3424,10 +3420,10 @@ def merge_asof(
data_columns = []
column_labels = []
- def left_scol_for(label: Label) -> Column: # type: ignore[valid-type]
+ def left_scol_for(label: Label) -> PySparkColumn:
return scol_for(as_of_joined_table, left_internal.spark_column_name_for(label))
- def right_scol_for(label: Label) -> Column: # type: ignore[valid-type]
+ def right_scol_for(label: Label) -> PySparkColumn:
return scol_for(as_of_joined_table, right_internal.spark_column_name_for(label))
for label in left_internal.column_labels:
@@ -3441,7 +3437,7 @@ def merge_asof(
pass
else:
col = col + left_suffix
- scol = scol.alias(col) # type: ignore[attr-defined]
+ scol = scol.alias(col)
label = tuple([str(label[0]) + left_suffix] + list(label[1:]))
exprs.append(scol)
data_columns.append(col)
@@ -3449,7 +3445,7 @@ def merge_asof(
for label in right_internal.column_labels:
# recover `right_prefix` here.
col = right_internal.spark_column_name_for(label)[len(right_prefix) :]
- scol = right_scol_for(label).alias(col) # type: ignore[attr-defined]
+ scol = right_scol_for(label).alias(col)
if label in duplicate_columns:
spark_column_name = left_internal.spark_column_name_for(label)
if spark_column_name in left_as_of_names + left_join_on_names and (
@@ -3458,7 +3454,7 @@ def merge_asof(
continue
else:
col = col + right_suffix
- scol = scol.alias(col) # type: ignore[attr-defined]
+ scol = scol.alias(col)
label = tuple([str(label[0]) + right_suffix] + list(label[1:]))
exprs.append(scol)
data_columns.append(col)
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index 8edc2c531b51..ac11f327258a 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -70,7 +70,7 @@ from pyspark.sql.types import (
NullType,
)
from pyspark.sql.window import Window
-from pyspark.sql.utils import get_column_class, get_window_class
+from pyspark.sql.utils import get_window_class
from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm.
from pyspark.pandas._typing import Axis, Dtype, Label, Name, Scalar, T
from pyspark.pandas.accessors import PandasOnSparkSeriesMethods
@@ -4171,11 +4171,10 @@ class Series(Frame, IndexOpsMixin, Generic[T]):
if self._internal.index_level > 1:
raise NotImplementedError("rank do not support MultiIndex now")
- Column = get_column_class()
if ascending:
- asc_func = Column.asc
+ asc_func = PySparkColumn.asc
else:
- asc_func = Column.desc
+ asc_func = PySparkColumn.desc
if method == "first":
window = (
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py
index cbb0469a474f..b73d24b12d9d 100644
--- a/python/pyspark/pandas/spark/accessors.py
+++ b/python/pyspark/pandas/spark/accessors.py
@@ -27,7 +27,7 @@ from pyspark.sql import Column as PySparkColumn, DataFrame as PySparkDataFrame
from pyspark.sql.types import DataType, StructType
from pyspark.pandas._typing import IndexOpsLike
from pyspark.pandas.internal import InternalField
-from pyspark.sql.utils import get_column_class, get_dataframe_class
+from pyspark.sql.utils import get_dataframe_class
if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
@@ -116,8 +116,7 @@ class SparkIndexOpsMethods(Generic[IndexOpsLike], metaclass=ABCMeta):
if isinstance(self._data, MultiIndex):
raise NotImplementedError("MultiIndex does not support spark.transform yet.")
output = func(self._data.spark.column)
- Column = get_column_class()
- if not isinstance(output, Column):
+ if not isinstance(output, PySparkColumn):
raise ValueError(
"The output of the function [%s] should be of a "
"pyspark.sql.Column; however, got [%s]." % (func, type(output))
@@ -192,8 +191,7 @@ class SparkSeriesMethods(SparkIndexOpsMethods["ps.Series"]):
from pyspark.pandas.internal import HIDDEN_COLUMNS
output = func(self._data.spark.column)
- Column = get_column_class()
- if not isinstance(output, Column):
+ if not isinstance(output, PySparkColumn):
raise ValueError(
"The output of the function [%s] should be of a "
"pyspark.sql.Column; however, got [%s]." % (func, type(output))
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 33e01ba378c4..df0451fa1bd2 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -57,7 +57,6 @@ if TYPE_CHECKING:
from pyspark import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.dataframe import DataFrame
- from pyspark.sql.column import Column
from pyspark.sql.window import Window
from pyspark.pandas._typing import IndexOpsLike, SeriesOrIndex
@@ -412,15 +411,9 @@ def pyspark_column_op(
Wrapper function for column_op to get proper Column class.
"""
from pyspark.pandas.base import column_op
- from pyspark.sql.column import Column as PySparkColumn
+ from pyspark.sql.column import Column
from pyspark.pandas.data_type_ops.base import _is_extension_dtypes
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- Column = ConnectColumn
- else:
- Column = PySparkColumn # type: ignore[assignment]
result = column_op(getattr(Column, func_name))(left, right)
# It works as expected on extension dtype, so we don't need to call `fillna` for this case.
if (fillna is not None) and (_is_extension_dtypes(left) or _is_extension_dtypes(right)):
@@ -429,17 +422,6 @@ def pyspark_column_op(
return result.fillna(fillna) if fillna is not None else result
-def get_column_class() -> Type["Column"]:
- from pyspark.sql.column import Column as PySparkColumn
-
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
-
- return ConnectColumn
- else:
- return PySparkColumn
-
-
def get_dataframe_class() -> Type["DataFrame"]:
from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
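With `get_column_class` gone, only the DataFrame dispatch remains in `sql/utils.py`. A sketch of why the Column variant is no longer needed, assuming (per the PR description) that both Classic and Connect column objects are instances of the parent `pyspark.sql.Column`:

```python
from pyspark.sql import SparkSession, functions as F, Column as PySparkColumn

spark = SparkSession.builder.getOrCreate()  # Classic or Connect session
col = F.lit(1)

# A single isinstance check against the parent class covers both modes,
# so call sites no longer need to pick a class via is_remote().
assert isinstance(col, PySparkColumn)

# The unbound-method style used throughout this diff also dispatches correctly:
expr = PySparkColumn.__add__(col, F.lit(2))
```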
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]