This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 3ba57f5  [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py
3ba57f5 is described below

commit 3ba57f5edc5594ee676249cd309b8f0d8248462e
Author: Xinrong Meng <xinrong.m...@databricks.com>
AuthorDate: Tue Oct 12 13:36:22 2021 -0700

    [SPARK-36951][PYTHON] Inline type hints for python/pyspark/sql/column.py

    ### What changes were proposed in this pull request?

    Inline type hints for python/pyspark/sql/column.py.

    ### Why are the changes needed?

    Currently, the type hints for column.py live in the stub file
    python/pyspark/sql/column.pyi, which doesn't support type checking
    within function bodies. So we inline the type hints to support that.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Existing tests.

    Closes #34226 from xinrong-databricks/inline_column.

    Authored-by: Xinrong Meng <xinrong.m...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
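For context, a minimal sketch of the difference the patch is after (not from
this commit; the greet() function is hypothetical). With hints kept in a .pyi
stub, mypy reads only the stub's signatures and never type-checks the
implementation body; inlining the annotations makes the body itself checkable:

    # greeting.pyi (stub): the checker sees only this signature, so the
    # unannotated implementation in greeting.py is never verified.
    #     def greet(name: str) -> str: ...

    # greeting.py with the hint inlined: mypy now checks the body and flags
    # the bug below: Incompatible return value type (got "int", expected "str")
    def greet(name: str) -> str:
        return len(name)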
---
 python/pyspark/sql/column.py      | 236 ++++++++++++++++++++++++++++----------
 python/pyspark/sql/column.pyi     | 118 -------------------
 python/pyspark/sql/dataframe.py   |  12 +-
 python/pyspark/sql/functions.py   |   3 +-
 python/pyspark/sql/observation.py |   5 +-
 python/pyspark/sql/window.py      |   4 +-
 6 files changed, 190 insertions(+), 188 deletions(-)

diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index c46b0eb..a3e3e9e 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -18,25 +18,43 @@ import sys
 import json
 import warnings
 
+from typing import (
+    cast,
+    overload,
+    Any,
+    Callable,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TYPE_CHECKING,
+    Union
+)
+
+from py4j.java_gateway import JavaObject
 
 from pyspark import copy_func
 from pyspark.context import SparkContext
 from pyspark.sql.types import DataType, StructField, StructType, IntegerType, StringType
 
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral
+    from pyspark.sql.window import WindowSpec
+
 __all__ = ["Column"]
 
 
-def _create_column_from_literal(literal):
-    sc = SparkContext._active_spark_context
+def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.lit(literal)
 
 
-def _create_column_from_name(name):
-    sc = SparkContext._active_spark_context
+def _create_column_from_name(name: str) -> "Column":
+    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     return sc._jvm.functions.col(name)
 
 
-def _to_java_column(col):
+def _to_java_column(col: "ColumnOrName") -> JavaObject:
     if isinstance(col, Column):
         jcol = col._jc
     elif isinstance(col, str):
@@ -50,7 +68,11 @@ def _to_java_column(col):
     return jcol
 
 
-def _to_seq(sc, cols, converter=None):
+def _to_seq(
+    sc: SparkContext,
+    cols: Iterable["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM Seq of Column.
 
@@ -59,10 +81,14 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toSeq(cols)
+    return sc._jvm.PythonUtils.toSeq(cols)  # type: ignore[attr-defined]
 
 
-def _to_list(sc, cols, converter=None):
+def _to_list(
+    sc: SparkContext,
+    cols: List["ColumnOrName"],
+    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
+) -> JavaObject:
     """
     Convert a list of Column (or names) into a JVM (Scala) List of Column.
@@ -71,30 +97,37 @@ def _to_list(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    return sc._jvm.PythonUtils.toList(cols)
+    return sc._jvm.PythonUtils.toList(cols)  # type: ignore[attr-defined]
 
 
-def _unary_op(name, doc="unary operator"):
+def _unary_op(
+    name: str,
+    doc: str = "unary operator",
+) -> Callable[["Column"], "Column"]:
     """ Create a method for given unary operator """
-    def _(self):
+    def _(self: "Column") -> "Column":
         jc = getattr(self._jc, name)()
         return Column(jc)
     _.__doc__ = doc
     return _
 
 
-def _func_op(name, doc=''):
-    def _(self):
-        sc = SparkContext._active_spark_context
+def _func_op(name: str, doc: str = '') -> Callable[["Column"], "Column"]:
+    def _(self: "Column") -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(sc._jvm.functions, name)(self._jc)
         return Column(jc)
     _.__doc__ = doc
     return _
 
 
-def _bin_func_op(name, reverse=False, doc="binary function"):
-    def _(self, other):
-        sc = SparkContext._active_spark_context
+def _bin_func_op(
+    name: str,
+    reverse: bool = False,
+    doc: str = "binary function",
+) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]:
+    def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column":
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         fn = getattr(sc._jvm.functions, name)
         jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
         njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
@@ -103,10 +136,19 @@ def _bin_func_op(name, reverse=False, doc="binary function"):
     return _
 
 
-def _bin_op(name, doc="binary operator"):
+def _bin_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[
+    ["Column", Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"]],
+    "Column"
+]:
     """ Create a method for given binary operator """
-    def _(self, other):
+    def _(
+        self: "Column",
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
         jc = other._jc if isinstance(other, Column) else other
         njc = getattr(self._jc, name)(jc)
         return Column(njc)
@@ -114,10 +156,13 @@ def _bin_op(name, doc="binary operator"):
     return _
 
 
-def _reverse_op(name, doc="binary operator"):
+def _reverse_op(
+    name: str,
+    doc: str = "binary operator",
+) -> Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"]:
     """ Create a method for binary operator (this object is on right side) """
-    def _(self, other):
+    def _(self: "Column", other: Union["LiteralType", "DecimalLiteral"]) -> "Column":
         jother = _create_column_from_literal(other)
         jc = getattr(jother, name)(self._jc)
         return Column(jc)
@@ -144,29 +189,81 @@ class Column(object):
     .. versionadded:: 1.3.0
     """
 
-    def __init__(self, jc):
+    def __init__(self, jc: JavaObject) -> None:
         self._jc = jc
 
     # arithmetic operators
     __neg__ = _func_op("negate")
-    __add__ = _bin_op("plus")
-    __sub__ = _bin_op("minus")
-    __mul__ = _bin_op("multiply")
-    __div__ = _bin_op("divide")
-    __truediv__ = _bin_op("divide")
-    __mod__ = _bin_op("mod")
-    __radd__ = _bin_op("plus")
-    __rsub__ = _reverse_op("minus")
-    __rmul__ = _bin_op("multiply")
-    __rdiv__ = _reverse_op("divide")
-    __rtruediv__ = _reverse_op("divide")
-    __rmod__ = _reverse_op("mod")
+    __add__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __sub__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("minus")
+    )
+    __mul__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __div__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __truediv__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("divide")
+    )
+    __mod__ = cast(
+        Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("mod")
+    )
+    __radd__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("plus")
+    )
+    __rsub__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("minus")
+    )
+    __rmul__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_op("multiply")
+    )
+    __rdiv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rtruediv__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("divide")
+    )
+    __rmod__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _reverse_op("mod")
+    )
+
     __pow__ = _bin_func_op("pow")
-    __rpow__ = _bin_func_op("pow", reverse=True)
+    __rpow__ = cast(
+        Callable[["Column", Union["LiteralType", "DecimalLiteral"]], "Column"],
+        _bin_func_op("pow", reverse=True)
+    )
 
     # logistic operators
-    __eq__ = _bin_op("equalTo")
-    __ne__ = _bin_op("notEqual")
+    def __eq__(  # type: ignore[override]
+        self,
+        other: Union["Column", "LiteralType", "DecimalLiteral", "DateTimeLiteral"],
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("equalTo")(self, other)
+
+    def __ne__(  # type: ignore[override]
+        self,
+        other: Any,
+    ) -> "Column":
+        """binary function"""
+        return _bin_op("notEqual")(self, other)
+
     __lt__ = _bin_op("lt")
     __le__ = _bin_op("leq")
     __ge__ = _bin_op("geq")
@@ -243,7 +340,7 @@ class Column(object):
     __ror__ = _bin_op("or")
 
     # container operators
-    def __contains__(self, item):
+    def __contains__(self, item: Any) -> None:
         raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' "
                          "in a string column or 'array_contains' function for an array column.")
 
@@ -301,7 +398,7 @@ class Column(object):
     bitwiseAND = _bin_op("bitwiseAND", _bitwiseAND_doc)
     bitwiseXOR = _bin_op("bitwiseXOR", _bitwiseXOR_doc)
 
-    def getItem(self, key):
+    def getItem(self, key: Any) -> "Column":
         """
         An expression that gets an item at position ``ordinal`` out of a list,
         or gets an item by key out of a dict.
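An aside on the cast() pattern used for the operator dunders above -- a toy
sketch under assumed names (Num and this _bin_op are illustrative, not Spark
code): the factory returns a loosely typed function, and typing.cast, a no-op
at runtime, records the precise operand signature for the type checker:

    from typing import Callable, Union, cast

    def _bin_op(name: str) -> Callable[["Num", object], "Num"]:
        # Factory returning a generic binary method, as in the patch.
        def _(self: "Num", other: object) -> "Num":
            rhs = other.value if isinstance(other, Num) else other
            return Num(getattr(self.value, name)(rhs))
        return _

    class Num:
        def __init__(self, value: float) -> None:
            self.value = value

        # cast() narrows the factory's loose type, so mypy checks
        # expressions like Num(1.0) + 2 against the exact operand types.
        __add__ = cast(Callable[["Num", Union["Num", float]], "Num"], _bin_op("__add__"))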
@@ -327,7 +424,7 @@ class Column(object):
         )
         return self[key]
 
-    def getField(self, name):
+    def getField(self, name: Any) -> "Column":
         """
         An expression that gets a field by name in a :class:`StructType`.
 
@@ -359,7 +456,7 @@ class Column(object):
         )
         return self[name]
 
-    def withField(self, fieldName, col):
+    def withField(self, fieldName: str, col: "Column") -> "Column":
         """
         An expression that adds/replaces a field in :class:`StructType` by name.
 
@@ -391,7 +488,7 @@ class Column(object):
 
         return Column(self._jc.withField(fieldName, col._jc))
 
-    def dropFields(self, *fieldNames):
+    def dropFields(self, *fieldNames: str) -> "Column":
         """
         An expression that drops fields in :class:`StructType` by name.
         This is a no-op if schema doesn't contain field name(s).
@@ -441,17 +538,17 @@ class Column(object):
         +--------------+
 
         """
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = self._jc.dropFields(_to_seq(sc, fieldNames))
         return Column(jc)
 
-    def __getattr__(self, item):
+    def __getattr__(self, item: Any) -> "Column":
         if item.startswith("__"):
             raise AttributeError(item)
         return self[item]
 
-    def __getitem__(self, k):
+    def __getitem__(self, k: Any) -> "Column":
         if isinstance(k, slice):
             if k.step is not None:
                 raise ValueError("slice with step is not supported.")
@@ -459,7 +556,7 @@ class Column(object):
         else:
             return _bin_op("apply")(self, k)
 
-    def __iter__(self):
+    def __iter__(self) -> None:
         raise TypeError("Column is not iterable")
 
     # string methods
@@ -565,7 +662,15 @@ class Column(object):
     startswith = _bin_op("startsWith", _startswith_doc)
     endswith = _bin_op("endsWith", _endswith_doc)
 
-    def substr(self, startPos, length):
+    @overload
+    def substr(self, startPos: int, length: int) -> "Column":
+        ...
+
+    @overload
+    def substr(self, startPos: "Column", length: "Column") -> "Column":
+        ...
+
+    def substr(self, startPos: Union[int, "Column"], length: Union[int, "Column"]) -> "Column":
         """
         Return a :class:`Column` which is a substring of the column.
 
@@ -594,12 +699,12 @@ class Column(object):
         if isinstance(startPos, int):
             jc = self._jc.substr(startPos, length)
         elif isinstance(startPos, Column):
-            jc = self._jc.substr(startPos._jc, length._jc)
+            jc = self._jc.substr(cast("Column", startPos)._jc, cast("Column", length)._jc)
         else:
             raise TypeError("Unexpected type: %s" % type(startPos))
         return Column(jc)
 
-    def isin(self, *cols):
+    def isin(self, *cols: Any) -> "Column":
         """
         A boolean expression that is evaluated to true if the value of this
         expression is contained by the evaluated values of the arguments.
@@ -614,9 +719,12 @@ class Column(object):
         [Row(age=2, name='Alice')]
         """
         if len(cols) == 1 and isinstance(cols[0], (list, set)):
-            cols = cols[0]
-        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
-        sc = SparkContext._active_spark_context
+            cols = cast(Tuple, cols[0])
+        cols = cast(
+            Tuple,
+            [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
+        )
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
         return Column(jc)
 
@@ -730,7 +838,7 @@ class Column(object):
     isNull = _unary_op("isNull", _isNull_doc)
     isNotNull = _unary_op("isNotNull", _isNotNull_doc)
 
-    def alias(self, *alias, **kwargs):
+    def alias(self, *alias: str, **kwargs: Any) -> "Column":
        """
        Returns this column aliased with a new name or names (in the case of expressions that
        return more than one column, such as explode).
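Likewise, a toy sketch of the @overload pattern introduced for substr() above
(Box and pick() are made-up names, not Spark code): the decorated stubs exist
only for the type checker, and the final undecorated definition is the single
runtime implementation, which must accept every declared signature:

    from typing import Union, overload

    class Box:
        @overload
        def pick(self, a: int, b: int) -> int: ...

        @overload
        def pick(self, a: str, b: str) -> str: ...

        def pick(self, a: Union[int, str], b: Union[int, str]) -> Union[int, str]:
            # Call sites are checked against the overloads, so a mixed call
            # like pick(1, "x") is rejected even though this body would run.
            return a if a == b else b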
@@ -763,7 +871,7 @@ class Column(object):
         metadata = kwargs.pop('metadata', None)
         assert not kwargs, 'Unexpected kwargs where passed: %s' % kwargs
 
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         if len(alias) == 1:
             if metadata:
                 jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
@@ -778,7 +886,7 @@ class Column(object):
 
     name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")
 
-    def cast(self, dataType):
+    def cast(self, dataType: Union[DataType, str]) -> "Column":
         """
         Casts the column into type ``dataType``.
 
@@ -804,7 +912,11 @@ class Column(object):
 
     astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.")
 
-    def between(self, lowerBound, upperBound):
+    def between(
+        self,
+        lowerBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+        upperBound: Union["Column", "LiteralType", "DateTimeLiteral", "DecimalLiteral"],
+    ) -> "Column":
         """
         True if the current column is between the lower bound and upper bound, inclusive.
 
@@ -822,7 +934,7 @@ class Column(object):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
-    def when(self, condition, value):
+    def when(self, condition: "Column", value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -857,7 +969,7 @@ class Column(object):
         jc = self._jc.when(condition._jc, v)
         return Column(jc)
 
-    def otherwise(self, value):
+    def otherwise(self, value: Any) -> "Column":
         """
         Evaluates a list of conditions and returns one of multiple possible result expressions.
         If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
@@ -888,7 +1000,7 @@ class Column(object):
         jc = self._jc.otherwise(v)
         return Column(jc)
 
-    def over(self, window):
+    def over(self, window: "WindowSpec") -> "Column":
         """
         Define a windowing column.
 
@@ -924,16 +1036,16 @@ class Column(object):
         jc = self._jc.over(window._jspec)
         return Column(jc)
 
-    def __nonzero__(self):
+    def __nonzero__(self) -> None:
         raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
                          "'~' for 'not' when building DataFrame boolean expressions.")
     __bool__ = __nonzero__
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "Column<'%s'>" % self._jc.toString()
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.column
diff --git a/python/pyspark/sql/column.pyi b/python/pyspark/sql/column.pyi
deleted file mode 100644
index 36c1bcc..0000000
--- a/python/pyspark/sql/column.pyi
+++ /dev/null
@@ -1,118 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Any, Union
-
-from pyspark.sql._typing import LiteralType, DecimalLiteral, DateTimeLiteral
-from pyspark.sql.types import (  # noqa: F401
-    DataType,
-    StructField,
-    StructType,
-    IntegerType,
-    StringType,
-)
-from pyspark.sql.window import WindowSpec
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-class Column:
-    def __init__(self, jc: JavaObject) -> None: ...
-    def __neg__(self) -> Column: ...
-    def __add__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __sub__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __mul__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __div__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __truediv__(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __mod__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __radd__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rsub__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmul__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rdiv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rtruediv__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __rmod__(self, other: Union[bool, int, float, DecimalLiteral]) -> Column: ...
-    def __pow__(self, other: Union[Column, LiteralType, DecimalLiteral]) -> Column: ...
-    def __rpow__(self, other: Union[LiteralType, DecimalLiteral]) -> Column: ...
-    def __eq__(self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]) -> Column: ...  # type: ignore[override]
-    def __ne__(self, other: Any) -> Column: ...  # type: ignore[override]
-    def __lt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __le__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __ge__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def __gt__(
-        self, other: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral]
-    ) -> Column: ...
-    def eqNullSafe(
-        self, other: Union[Column, LiteralType, DecimalLiteral]
-    ) -> Column: ...
-    def __and__(self, other: Column) -> Column: ...
-    def __or__(self, other: Column) -> Column: ...
-    def __invert__(self) -> Column: ...
-    def __rand__(self, other: Column) -> Column: ...
-    def __ror__(self, other: Column) -> Column: ...
-    def __contains__(self, other: Any) -> Column: ...
-    def __getitem__(self, other: Any) -> Column: ...
-    def bitwiseOR(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseAND(self, other: Union[Column, int]) -> Column: ...
-    def bitwiseXOR(self, other: Union[Column, int]) -> Column: ...
-    def getItem(self, key: Any) -> Column: ...
-    def getField(self, name: Any) -> Column: ...
-    def withField(self, fieldName: str, col: Column) -> Column: ...
-    def dropFields(self, *fieldNames: str) -> Column: ...
-    def __getattr__(self, item: Any) -> Column: ...
-    def __iter__(self) -> None: ...
-    def rlike(self, item: str) -> Column: ...
-    def like(self, item: str) -> Column: ...
-    def startswith(self, item: Union[str, Column]) -> Column: ...
-    def endswith(self, item: Union[str, Column]) -> Column: ...
-    @overload
-    def substr(self, startPos: int, length: int) -> Column: ...
-    @overload
-    def substr(self, startPos: Column, length: Column) -> Column: ...
-    def __getslice__(self, startPos: int, length: int) -> Column: ...
-    def isin(self, *cols: Any) -> Column: ...
-    def asc(self) -> Column: ...
-    def asc_nulls_first(self) -> Column: ...
-    def asc_nulls_last(self) -> Column: ...
-    def desc(self) -> Column: ...
-    def desc_nulls_first(self) -> Column: ...
-    def desc_nulls_last(self) -> Column: ...
-    def isNull(self) -> Column: ...
-    def isNotNull(self) -> Column: ...
-    def alias(self, *alias: str, **kwargs: Any) -> Column: ...
-    def name(self, *alias: str) -> Column: ...
-    def cast(self, dataType: Union[DataType, str]) -> Column: ...
-    def astype(self, dataType: Union[DataType, str]) -> Column: ...
-    def between(
-        self,
-        lowerBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-        upperBound: Union[Column, LiteralType, DateTimeLiteral, DecimalLiteral],
-    ) -> Column: ...
-    def when(self, condition: Column, value: Any) -> Column: ...
-    def otherwise(self, value: Any) -> Column: ...
-    def over(self, window: WindowSpec) -> Column: ...
-    def __nonzero__(self) -> None: ...
-    def __bool__(self) -> None: ...
-    def contains(self, item: Any) -> Column: ...
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 339f8f8..223f041 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1279,7 +1279,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise ValueError("Weights must be positive. Found weight value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         rdd_array = self._jdf.randomSplit(
-            _to_list(self.sql_ctx._sc, weights), int(seed)  # type: ignore[attr-defined]
+            _to_list(
+                self.sql_ctx._sc,  # type: ignore[attr-defined]
+                cast(List["ColumnOrName"], weights)
+            ),
+            int(seed)
         )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
 
@@ -1674,7 +1678,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise ValueError("should sort by at least one column")
         if len(cols) == 1 and isinstance(cols[0], list):
             cols = cols[0]
-        jcols = [_to_java_column(c) for c in cols]
+        jcols = [_to_java_column(cast("ColumnOrName", c)) for c in cols]
         ascending = kwargs.get('ascending', True)
         if isinstance(ascending, (bool, int)):
             if not ascending:
@@ -2723,7 +2727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             for c in col:
                 if not isinstance(c, str):
                     raise TypeError("columns should be strings, but got %r" % type(c))
-            col = _to_list(self._sc, col)
+            col = _to_list(self._sc, cast(List["ColumnOrName"], col))
 
         if not isinstance(probabilities, (list, tuple)):
             raise TypeError("probabilities should be a list or tuple")
@@ -2732,7 +2736,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         for p in probabilities:
             if not isinstance(p, (float, int)) or p < 0 or p > 1:
                 raise ValueError("probabilities should be numerical (float, int) in [0,1].")
-        probabilities = _to_list(self._sc, probabilities)
+        probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))
 
         if not isinstance(relativeError, (float, int)):
             raise TypeError("relativeError should be numerical (float, int)")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 7e0d015..717eaec 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -24,6 +24,7 @@ import functools
 import warnings
 from typing import (
     Any,
+    cast,
     Callable,
     Dict,
     List,
@@ -1770,7 +1771,7 @@ def log(arg1: Union["ColumnOrName", float], arg2: Optional["ColumnOrName"] = Non
     """
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if arg2 is None:
-        jc = sc._jvm.functions.log(_to_java_column(arg1))
+        jc = sc._jvm.functions.log(_to_java_column(cast("ColumnOrName", arg1)))
     else:
         jc = sc._jvm.functions.log(arg1, _to_java_column(arg2))
     return Column(jc)
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 48d8176..f60e580 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark.sql import column
 from pyspark.sql.column import Column
@@ -22,6 +22,9 @@ from pyspark.sql.dataframe import DataFrame
 
 __all__ = ["Observation"]
 
+if TYPE_CHECKING:
+    from pyspark import SparkContext  # noqa: F401
+
 
 class Observation:
     """Class to observe (named) metrics on a :class:`DataFrame`.
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 3054273..f1b03ab 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -16,7 +16,7 @@
 #
 import sys
-from typing import List, Tuple, TYPE_CHECKING, Union
+from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union
 
 from pyspark import since, SparkContext
 from pyspark.sql.column import _to_seq, _to_java_column  # type: ignore[attr-defined]
@@ -35,7 +35,7 @@ def _to_java_cols(
     sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
     if len(cols) == 1 and isinstance(cols[0], list):
         cols = cols[0]  # type: ignore[assignment]
-    return _to_seq(sc, cols, _to_java_column)
+    return _to_seq(sc, cast(Iterable["ColumnOrName"], cols), _to_java_column)
 
 
 class Window(object):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org