This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9995e62 [SPARK-36885][PYTHON] Inline type hints for pyspark.sql.dataframe 9995e62 is described below commit 9995e623f7f65fdd3b1dc3cd4e0140a7cf4bc4a0 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Tue Oct 12 09:17:14 2021 +0900 [SPARK-36885][PYTHON] Inline type hints for pyspark.sql.dataframe ### What changes were proposed in this pull request? Inline type hints from `python/pyspark/sql/dataframe.pyi` to `python/pyspark/sql/dataframe.py`. ### Why are the changes needed? Currently, there are type hint stub files such as `python/pyspark/sql/dataframe.pyi` to show the expected types for functions, but we can also take advantage of static type checking within the functions by inlining the type hints. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. Closes #34225 from ueshin/issues/SPARK-36885/inline_typehints_dataframe. Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/__init__.pyi | 2 +- python/pyspark/sql/dataframe.py | 671 ++++++++++++++++++++++++++++---------- python/pyspark/sql/dataframe.pyi | 351 -------------------- python/pyspark/sql/observation.py | 2 +- 4 files changed, 504 insertions(+), 522 deletions(-) diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi index f85319b..35df545 100644 --- a/python/pyspark/__init__.pyi +++ b/python/pyspark/__init__.pyi @@ -71,7 +71,7 @@ def since(version: Union[str, float]) -> Callable[[T], T]: ... def copy_func( f: F, name: Optional[str] = ..., - sinceversion: Optional[str] = ..., + sinceversion: Optional[Union[str, float]] = ..., doc: Optional[str] = ..., ) -> F: ... def keyword_only(func: F) -> F: ... 
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8d4c94f..339f8f8 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -22,22 +22,41 @@ import warnings from collections.abc import Iterable from functools import reduce from html import escape as html_escape +from typing import ( + Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, Union, + cast, overload, TYPE_CHECKING +) -from pyspark import copy_func, since, _NoValue +from py4j.java_gateway import JavaObject # type: ignore[import] + +from pyspark import copy_func, since, _NoValue # type: ignore[attr-defined] from pyspark.context import SparkContext -from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket +from pyspark.rdd import ( # type: ignore[attr-defined] + RDD, _load_from_socket, _local_iterator_from_socket +) from pyspark.serializers import BatchedSerializer, PickleSerializer, \ UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from pyspark.sql.types import _parse_datatype_json_string -from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column +from pyspark.sql.types import _parse_datatype_json_string # type: ignore[attr-defined] +from pyspark.sql.column import ( # type: ignore[attr-defined] + Column, _to_seq, _to_list, _to_java_column +) from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 from pyspark.sql.streaming import DataStreamWriter -from pyspark.sql.types import StructType, StructField, StringType, IntegerType +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, Row from pyspark.sql.pandas.conversion import PandasConversionMixin from pyspark.sql.pandas.map_ops import PandasMapOpsMixin +if TYPE_CHECKING: + from pyspark._typing import PrimitiveType + from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame + from pyspark.sql._typing import ColumnOrName, 
LiteralType, OptionalPrimitiveType + from pyspark.sql.context import SQLContext + from pyspark.sql.group import GroupedData + from pyspark.sql.observation import Observation + + __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] @@ -68,42 +87,49 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): .. versionadded:: 1.3.0 """ - def __init__(self, jdf, sql_ctx): + def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"): self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx and sql_ctx._sc + self._sc = cast( + SparkContext, + sql_ctx and sql_ctx._sc # type: ignore[attr-defined] + ) self.is_cached = False - self._schema = None # initialized lazily - self._lazy_rdd = None + # initialized lazily + self._schema: Optional[StructType] = None + self._lazy_rdd: Optional[RDD[Row]] = None # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice # by __repr__ and _repr_html_ while eager evaluation opened. self._support_repr_html = False - @property + @property # type: ignore[misc] @since(1.3) - def rdd(self): + def rdd(self) -> "RDD[Row]": """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. """ if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + self._lazy_rdd = RDD( + jrdd, self.sql_ctx._sc, # type: ignore[attr-defined] + BatchedSerializer(PickleSerializer()) + ) return self._lazy_rdd - @property + @property # type: ignore[misc] @since("1.3.1") - def na(self): + def na(self) -> "DataFrameNaFunctions": """Returns a :class:`DataFrameNaFunctions` for handling missing values. """ return DataFrameNaFunctions(self) - @property + @property # type: ignore[misc] @since(1.4) - def stat(self): + def stat(self) -> "DataFrameStatFunctions": """Returns a :class:`DataFrameStatFunctions` for statistic functions. 
""" return DataFrameStatFunctions(self) - def toJSON(self, use_unicode=True): + def toJSON(self, use_unicode: bool = True) -> "RDD[str]": """Converts a :class:`DataFrame` into a :class:`RDD` of string. Each row is turned into a JSON document as one element in the returned RDD. @@ -118,7 +144,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - def registerTempTable(self, name): + def registerTempTable(self, name: str) -> None: """Registers this :class:`DataFrame` as a temporary table using the given name. The lifetime of this temporary table is tied to the :class:`SparkSession` @@ -145,7 +171,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): ) self._jdf.createOrReplaceTempView(name) - def createTempView(self, name): + def createTempView(self, name: str) -> None: """Creates a local temporary view with this :class:`DataFrame`. The lifetime of this temporary table is tied to the :class:`SparkSession` @@ -171,7 +197,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ self._jdf.createTempView(name) - def createOrReplaceTempView(self, name): + def createOrReplaceTempView(self, name: str) -> None: """Creates or replaces a local temporary view with this :class:`DataFrame`. The lifetime of this temporary table is tied to the :class:`SparkSession` @@ -193,7 +219,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ self._jdf.createOrReplaceTempView(name) - def createGlobalTempView(self, name): + def createGlobalTempView(self, name: str) -> None: """Creates a global temporary view with this :class:`DataFrame`. The lifetime of this temporary view is tied to this Spark application. 
@@ -218,7 +244,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ self._jdf.createGlobalTempView(name) - def createOrReplaceGlobalTempView(self, name): + def createOrReplaceGlobalTempView(self, name: str) -> None: """Creates or replaces a global temporary view using the given name. The lifetime of this temporary view is tied to this Spark application. @@ -240,7 +266,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): self._jdf.createOrReplaceGlobalTempView(name) @property - def write(self): + def write(self) -> DataFrameWriter: """ Interface for saving the content of the non-streaming :class:`DataFrame` out into external storage. @@ -254,7 +280,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrameWriter(self) @property - def writeStream(self): + def writeStream(self) -> DataStreamWriter: """ Interface for saving the content of the streaming :class:`DataFrame` out into external storage. @@ -272,7 +298,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataStreamWriter(self) @property - def schema(self): + def schema(self) -> StructType: """Returns the schema of this :class:`DataFrame` as a :class:`pyspark.sql.types.StructType`. .. versionadded:: 1.3.0 @@ -290,7 +316,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): "Unable to parse datatype from schema. %s" % e) from e return self._schema - def printSchema(self): + def printSchema(self) -> None: """Prints out the schema in the tree format. .. versionadded:: 1.3.0 @@ -305,7 +331,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ print(self._jdf.schema().treeString()) - def explain(self, extended=None, mode=None): + def explain( + self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None + ) -> None: """Prints the (logical and physical) plans to the console for debugging purpose. .. 
versionadded:: 1.3.0 @@ -390,13 +418,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): elif is_extended_case: explain_mode = "extended" if extended else "simple" elif is_mode_case: - explain_mode = mode + explain_mode = cast(str, mode) elif is_extended_as_mode: - explain_mode = extended + explain_mode = cast(str, extended) - print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode)) + print( + self._sc._jvm # type: ignore[attr-defined] + .PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode) + ) - def exceptAll(self, other): + def exceptAll(self, other: "DataFrame") -> "DataFrame": """Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but not in another :class:`DataFrame` while preserving duplicates. @@ -425,14 +456,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx) @since(1.3) - def isLocal(self): + def isLocal(self) -> bool: """Returns ``True`` if the :func:`collect` and :func:`take` methods can be run locally (without any Spark executors). """ return self._jdf.isLocal() @property - def isStreaming(self): + def isStreaming(self) -> bool: """Returns ``True`` if this :class:`DataFrame` contains one or more sources that continuously return data as it arrives. A :class:`DataFrame` that reads data from a streaming source must be executed as a :class:`StreamingQuery` using the :func:`start` @@ -448,7 +479,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return self._jdf.isStreaming() - def show(self, n=20, truncate=True, vertical=False): + def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: """Prints the first ``n`` rows to the console. .. 
versionadded:: 1.3.0 @@ -509,26 +540,33 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): print(self._jdf.showString(n, int_truncate, vertical)) - def __repr__(self): - if not self._support_repr_html and self.sql_ctx._conf.isReplEagerEvalEnabled(): + def __repr__(self) -> str: + if ( + not self._support_repr_html + and self.sql_ctx._conf.isReplEagerEvalEnabled() # type: ignore[attr-defined] + ): vertical = False return self._jdf.showString( - self.sql_ctx._conf.replEagerEvalMaxNumRows(), - self.sql_ctx._conf.replEagerEvalTruncate(), vertical) + self.sql_ctx._conf.replEagerEvalMaxNumRows(), # type: ignore[attr-defined] + self.sql_ctx._conf.replEagerEvalTruncate(), vertical) # type: ignore[attr-defined] else: return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) - def _repr_html_(self): + def _repr_html_(self) -> Optional[str]: """Returns a :class:`DataFrame` with html code when you enabled eager evaluation by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are using support eager evaluation with HTML. """ if not self._support_repr_html: self._support_repr_html = True - if self.sql_ctx._conf.isReplEagerEvalEnabled(): - max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0) + if self.sql_ctx._conf.isReplEagerEvalEnabled(): # type: ignore[attr-defined] + max_num_rows = max( + self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0 # type: ignore[attr-defined] + ) sock_info = self._jdf.getRowsToPython( - max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate()) + max_num_rows, + self.sql_ctx._conf.replEagerEvalTruncate() # type: ignore[attr-defined] + ) rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) head = rows[0] row_data = rows[1:] @@ -550,7 +588,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): else: return None - def checkpoint(self, eager=True): + def checkpoint(self, eager: bool = True) -> "DataFrame": """Returns a checkpointed version of this :class:`DataFrame`. 
Checkpointing can be used to truncate the logical plan of this :class:`DataFrame`, which is especially useful in iterative algorithms where the plan may grow exponentially. It will be saved to files @@ -570,7 +608,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.checkpoint(eager) return DataFrame(jdf, self.sql_ctx) - def localCheckpoint(self, eager=True): + def localCheckpoint(self, eager: bool = True) -> "DataFrame": """Returns a locally checkpointed version of this :class:`DataFrame`. Checkpointing can be used to truncate the logical plan of this :class:`DataFrame`, which is especially useful in iterative algorithms where the plan may grow exponentially. Local checkpoints are @@ -590,7 +628,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.localCheckpoint(eager) return DataFrame(jdf, self.sql_ctx) - def withWatermark(self, eventTime, delayThreshold): + def withWatermark(self, eventTime: str, delayThreshold: str) -> "DataFrame": """Defines an event time watermark for this :class:`DataFrame`. A watermark tracks a point in time before which we assume no more late data is going to arrive. @@ -634,7 +672,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.withWatermark(eventTime, delayThreshold) return DataFrame(jdf, self.sql_ctx) - def hint(self, name, *parameters): + def hint( + self, name: str, *parameters: Union["PrimitiveType", List["PrimitiveType"]] + ) -> "DataFrame": """Specifies some hint on the current :class:`DataFrame`. .. 
versionadded:: 2.2.0 @@ -660,7 +700,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): +----+---+------+ """ if len(parameters) == 1 and isinstance(parameters[0], list): - parameters = parameters[0] + parameters = parameters[0] # type: ignore[assignment] if not isinstance(name, str): raise TypeError("name should be provided as str, got {0}".format(type(name))) @@ -675,7 +715,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.hint(name, self._jseq(parameters)) return DataFrame(jdf, self.sql_ctx) - def count(self): + def count(self) -> int: """Returns the number of rows in this :class:`DataFrame`. .. versionadded:: 1.3.0 @@ -687,7 +727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return int(self._jdf.count()) - def collect(self): + def collect(self) -> List[Row]: """Returns all the records as a list of :class:`Row`. .. versionadded:: 1.3.0 @@ -701,7 +741,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): sock_info = self._jdf.collectToPython() return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) - def toLocalIterator(self, prefetchPartitions=False): + def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: """ Returns an iterator that contains all of the rows in this :class:`DataFrame`. The iterator will consume as much memory as the largest partition in this @@ -724,7 +764,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): sock_info = self._jdf.toPythonIterator(prefetchPartitions) return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) - def limit(self, num): + def limit(self, num: int) -> "DataFrame": """Limits the result count to the number specified. .. 
versionadded:: 1.3.0 @@ -739,7 +779,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.limit(num) return DataFrame(jdf, self.sql_ctx) - def take(self, num): + def take(self, num: int) -> List[Row]: """Returns the first ``num`` rows as a :class:`list` of :class:`Row`. .. versionadded:: 1.3.0 @@ -751,7 +791,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return self.limit(num).collect() - def tail(self, num): + def tail(self, num: int) -> List[Row]: """ Returns the last ``num`` rows as a :class:`list` of :class:`Row`. @@ -769,7 +809,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): sock_info = self._jdf.tailToPython(num) return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) - def foreach(self, f): + def foreach(self, f: Callable[[Row], None]) -> None: """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. This is a shorthand for ``df.rdd.foreach()``. @@ -784,7 +824,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ self.rdd.foreach(f) - def foreachPartition(self, f): + def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: """Applies the ``f`` function to each partition of this :class:`DataFrame`. This a shorthand for ``df.rdd.foreachPartition()``. @@ -798,9 +838,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): ... print(person.name) >>> df.foreachPartition(f) """ - self.rdd.foreachPartition(f) + self.rdd.foreachPartition(f) # type: ignore[arg-type] - def cache(self): + def cache(self) -> "DataFrame": """Persists the :class:`DataFrame` with the default storage level (`MEMORY_AND_DISK`). .. 
versionadded:: 1.3.0 @@ -813,7 +853,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): self._jdf.cache() return self - def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK_DESER): + def persist( + self, + storageLevel: StorageLevel = ( + StorageLevel.MEMORY_AND_DISK_DESER # type: ignore[attr-defined] + ) + ) -> "DataFrame": """Sets the storage level to persist the contents of the :class:`DataFrame` across operations after the first time it is computed. This can only be used to assign a new storage level if the :class:`DataFrame` does not have a storage level set yet. @@ -826,12 +871,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): The default storage level has changed to `MEMORY_AND_DISK_DESER` to match Scala in 3.0. """ self.is_cached = True - javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) # type: ignore[attr-defined] self._jdf.persist(javaStorageLevel) return self @property - def storageLevel(self): + def storageLevel(self) -> StorageLevel: """Get the :class:`DataFrame`'s current storage level. .. versionadded:: 2.1.0 @@ -853,7 +898,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): java_storage_level.replication()) return storage_level - def unpersist(self, blocking=False): + def unpersist(self, blocking: bool = False) -> "DataFrame": """Marks the :class:`DataFrame` as non-persistent, and remove all blocks for it from memory and disk. @@ -867,7 +912,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): self._jdf.unpersist(blocking) return self - def coalesce(self, numPartitions): + def coalesce(self, numPartitions: int) -> "DataFrame": """ Returns a new :class:`DataFrame` that has exactly `numPartitions` partitions. 
@@ -898,7 +943,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx) - def repartition(self, numPartitions, *cols): + @overload + def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame": + ... + + @overload + def repartition(self, *cols: "ColumnOrName") -> "DataFrame": + ... + + def repartition( # type: ignore[misc] + self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" + ) -> "DataFrame": """ Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting :class:`DataFrame` is hash partitioned. @@ -967,7 +1022,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): else: raise TypeError("numPartitions should be an int or Column") - def repartitionByRange(self, numPartitions, *cols): + @overload + def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> "DataFrame": + ... + + @overload + def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame": + ... + + def repartitionByRange( # type: ignore[misc] + self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" + ) -> "DataFrame": """ Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The resulting :class:`DataFrame` is range partitioned. 
@@ -1017,7 +1082,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ if isinstance(numPartitions, int): if len(cols) == 0: - return ValueError("At least one partition-by expression must be specified.") + raise ValueError("At least one partition-by expression must be specified.") else: return DataFrame( self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sql_ctx) @@ -1027,7 +1092,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): else: raise TypeError("numPartitions should be an int, string or Column") - def distinct(self): + def distinct(self) -> "DataFrame": """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. .. versionadded:: 1.3.0 @@ -1039,7 +1104,25 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrame(self._jdf.distinct(), self.sql_ctx) - def sample(self, withReplacement=None, fraction=None, seed=None): + @overload + def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame": + ... + + @overload + def sample( + self, + withReplacement: Optional[bool], + fraction: float, + seed: Optional[int] = ..., + ) -> "DataFrame": + ... + + def sample( # type: ignore[misc] + self, + withReplacement: Optional[Union[float, bool]] = None, + fraction: Optional[Union[int, float]] = None, + seed: Optional[int] = None + ) -> "DataFrame": """Returns a sampled subset of this :class:`DataFrame`. .. 
versionadded:: 1.3.0 @@ -1105,7 +1188,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): if is_withReplacement_omitted_args: if fraction is not None: - seed = fraction + seed = cast(int, fraction) fraction = withReplacement withReplacement = None @@ -1114,7 +1197,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.sample(*args) return DataFrame(jdf, self.sql_ctx) - def sampleBy(self, col, fractions, seed=None): + def sampleBy( + self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None + ) -> "DataFrame": """ Returns a stratified sample without replacement based on the fraction given on each stratum. @@ -1167,7 +1252,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): seed = seed if seed is not None else random.randint(0, sys.maxsize) return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) - def randomSplit(self, weights, seed=None): + def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> List["DataFrame"]: """Randomly splits this :class:`DataFrame` with the provided weights. .. versionadded:: 1.4.0 @@ -1193,11 +1278,13 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): if w < 0.0: raise ValueError("Weights must be positive. Found weight value: %s" % w) seed = seed if seed is not None else random.randint(0, sys.maxsize) - rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), int(seed)) + rdd_array = self._jdf.randomSplit( + _to_list(self.sql_ctx._sc, weights), int(seed) # type: ignore[attr-defined] + ) return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array] @property - def dtypes(self): + def dtypes(self) -> List[Tuple[str, str]]: """Returns all column names and their data types as a list. .. 
versionadded:: 1.3.0 @@ -1210,7 +1297,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property - def columns(self): + def columns(self) -> List[str]: """Returns all column names as a list. .. versionadded:: 1.3.0 @@ -1222,7 +1309,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return [f.name for f in self.schema.fields] - def colRegex(self, colName): + def colRegex(self, colName: str) -> Column: """ Selects column based on the column name specified as a regex and returns it as :class:`Column`. @@ -1251,7 +1338,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jc = self._jdf.colRegex(colName) return Column(jc) - def alias(self, alias): + def alias(self, alias: str) -> "DataFrame": """Returns a new :class:`DataFrame` with an alias set. .. versionadded:: 1.3.0 @@ -1274,7 +1361,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): assert isinstance(alias, str), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx) - def crossJoin(self, other): + def crossJoin(self, other: "DataFrame") -> "DataFrame": """Returns the cartesian product with another :class:`DataFrame`. .. versionadded:: 2.1.0 @@ -1298,7 +1385,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.crossJoin(other._jdf) return DataFrame(jdf, self.sql_ctx) - def join(self, other, on=None, how=None): + def join( + self, + other: "DataFrame", + on: Optional[Union[str, List[str], Column, List[Column]]] = None, + how: Optional[str] = None + ) -> "DataFrame": """Joins with another :class:`DataFrame`, using the given join expression. .. 
versionadded:: 1.3.0 @@ -1342,14 +1434,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ if on is not None and not isinstance(on, list): - on = [on] + on = [on] # type: ignore[assignment] if on is not None: if isinstance(on[0], str): - on = self._jseq(on) + on = self._jseq(cast(List[str], on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), on) + on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on)) on = on._jc if on is None and how is None: @@ -1366,16 +1458,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): # TODO(SPARK-22947): Fix the DataFrame API. def _joinAsOf( self, - other, - leftAsOfColumn, - rightAsOfColumn, - on=None, - how=None, + other: "DataFrame", + leftAsOfColumn: Union[str, Column], + rightAsOfColumn: Union[str, Column], + on: Optional[Union[str, List[str], Column, List[Column]]] = None, + how: Optional[str] = None, *, - tolerance=None, - allowExactMatches=True, - direction="backward", - ): + tolerance: Optional[Column] = None, + allowExactMatches: bool = True, + direction: str = "backward", + ) -> "DataFrame": """ Perform an as-of join. 
@@ -1448,20 +1540,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ if isinstance(leftAsOfColumn, str): leftAsOfColumn = self[leftAsOfColumn] - left_as_of_jcol = leftAsOfColumn._jc + left_as_of_jcol = cast(Column, leftAsOfColumn)._jc if isinstance(rightAsOfColumn, str): rightAsOfColumn = other[rightAsOfColumn] - right_as_of_jcol = rightAsOfColumn._jc + right_as_of_jcol = cast(Column, rightAsOfColumn)._jc if on is not None and not isinstance(on, list): - on = [on] + on = [on] # type: ignore[assignment] if on is not None: if isinstance(on[0], str): - on = self._jseq(on) + on = self._jseq(cast(List[str], on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" - on = reduce(lambda x, y: x.__and__(y), on) + on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on)) on = on._jc if how is None: @@ -1480,7 +1572,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): ) return DataFrame(jdf, self.sql_ctx) - def sortWithinPartitions(self, *cols, **kwargs): + def sortWithinPartitions( + self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any + ) -> "DataFrame": """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). .. versionadded:: 1.6.0 @@ -1510,7 +1604,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) - def sort(self, *cols, **kwargs): + def sort( + self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any + ) -> "DataFrame": """Returns a new :class:`DataFrame` sorted by the specified column(s). .. 
versionadded:: 1.3.0 @@ -1548,15 +1644,19 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): orderBy = sort - def _jseq(self, cols, converter=None): + def _jseq( + self, + cols: Sequence, + converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None + ) -> JavaObject: """Return a JVM Seq of Columns from a list of Column or names""" - return _to_seq(self.sql_ctx._sc, cols, converter) + return _to_seq(self.sql_ctx._sc, cols, converter) # type: ignore[attr-defined] - def _jmap(self, jm): + def _jmap(self, jm: Dict) -> JavaObject: """Return a JVM Scala Map from a dict""" - return _to_scala_map(self.sql_ctx._sc, jm) + return _to_scala_map(self.sql_ctx._sc, jm) # type: ignore[attr-defined] - def _jcols(self, *cols): + def _jcols(self, *cols: "ColumnOrName") -> JavaObject: """Return a JVM Seq of Columns from a list of Column or column names If `cols` has only one list in it, cols[0] will be used as the list. @@ -1565,7 +1665,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): cols = cols[0] return self._jseq(cols, _to_java_column) - def _sort_cols(self, cols, kwargs): + def _sort_cols( + self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any] + ) -> JavaObject: """ Return a JVM Seq of Columns that describes the sort order """ if not cols: @@ -1584,7 +1686,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) return self._jseq(jcols) - def describe(self, *cols): + def describe(self, *cols: Union[str, List[str]]) -> "DataFrame": """Computes basic statistics for numeric and string columns. .. 
versionadded:: 1.3.1 @@ -1628,11 +1730,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): DataFrame.summary """ if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] + cols = cols[0] # type: ignore[assignment] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) - def summary(self, *statistics): + def summary(self, *statistics: str) -> "DataFrame": """Computes specified statistics for numeric and string columns. Available statistics are: - count - mean @@ -1697,7 +1799,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.summary(self._jseq(statistics)) return DataFrame(jdf, self.sql_ctx) - def head(self, n=None): + @overload + def head(self) -> Optional[Row]: + ... + + @overload + def head(self, n: int) -> List[Row]: + ... + + def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: """Returns the first ``n`` rows. .. versionadded:: 1.3.0 @@ -1729,7 +1839,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return rs[0] if rs else None return self.take(n) - def first(self): + def first(self) -> Optional[Row]: """Returns the first row as a :class:`Row`. .. versionadded:: 1.3.0 @@ -1741,7 +1851,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return self.head() - def __getitem__(self, item): + @overload + def __getitem__(self, item: Union[int, str]) -> Column: + ... + + @overload + def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame": + ... + + def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. .. versionadded:: 1.3.0 @@ -1770,7 +1888,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): else: raise TypeError("unexpected item type: %s" % type(item)) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Column: """Returns the :class:`Column` denoted by ``name``. .. 
versionadded:: 1.3.0 @@ -1786,7 +1904,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jc = self._jdf.apply(name) return Column(jc) - def select(self, *cols): + @overload + def select(self, *cols: "ColumnOrName") -> "DataFrame": + ... + + @overload + def select(self, __cols: Union[List[Column], List[str]]) -> "DataFrame": + ... + + def select(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] """Projects a set of expressions and returns a new :class:`DataFrame`. .. versionadded:: 1.3.0 @@ -1810,7 +1936,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sql_ctx) - def selectExpr(self, *expr): + @overload + def selectExpr(self, *expr: str) -> "DataFrame": + ... + + @overload + def selectExpr(self, *expr: List[str]) -> "DataFrame": + ... + + def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame": """Projects a set of SQL expressions and returns a new :class:`DataFrame`. This is a variant of :func:`select` that accepts SQL expressions. @@ -1823,11 +1957,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): - expr = expr[0] + expr = expr[0] # type: ignore[assignment] jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sql_ctx) - def filter(self, condition): + def filter(self, condition: "ColumnOrName") -> "DataFrame": """Filters rows using the given condition. :func:`where` is an alias for :func:`filter`. @@ -1860,7 +1994,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("condition should be string or Column") return DataFrame(jdf, self.sql_ctx) - def groupBy(self, *cols): + @overload + def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": + ... + + @overload + def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": + ... 
+ + def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """Groups the :class:`DataFrame` using the specified columns, so we can run aggregation on them. See :class:`GroupedData` for all the available aggregate functions. @@ -1890,7 +2032,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): from pyspark.sql.group import GroupedData return GroupedData(jgd, self) - def rollup(self, *cols): + @overload + def rollup(self, *cols: "ColumnOrName") -> "GroupedData": + ... + + @overload + def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": + ... + + def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """ Create a multi-dimensional rollup for the current :class:`DataFrame` using the specified columns, so we can run aggregation on them. @@ -1914,7 +2064,15 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): from pyspark.sql.group import GroupedData return GroupedData(jgd, self) - def cube(self, *cols): + @overload + def cube(self, *cols: "ColumnOrName") -> "GroupedData": + ... + + @overload + def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": + ... + + def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """ Create a multi-dimensional cube for the current :class:`DataFrame` using the specified columns, so we can run aggregations on them. @@ -1940,7 +2098,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): from pyspark.sql.group import GroupedData return GroupedData(jgd, self) - def agg(self, *exprs): + def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame": """ Aggregate on the entire :class:`DataFrame` without groups (shorthand for ``df.groupBy().agg()``). 
@@ -1954,10 +2112,10 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): >>> df.agg(F.min(df.age)).collect() [Row(min(age)=2)] """ - return self.groupBy().agg(*exprs) + return self.groupBy().agg(*exprs) # type: ignore[arg-type] @since(3.3) - def observe(self, observation, *exprs): + def observe(self, observation: "Observation", *exprs: Column) -> "DataFrame": """Observe (named) metrics through an :class:`Observation` instance. A user can retrieve the metrics by accessing `Observation.get`. @@ -1996,7 +2154,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return observation._on(self, *exprs) @since(2.0) - def union(self, other): + def union(self, other: "DataFrame") -> "DataFrame": """ Return a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -2008,7 +2166,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.union(other._jdf), self.sql_ctx) @since(1.3) - def unionAll(self, other): + def unionAll(self, other: "DataFrame") -> "DataFrame": """ Return a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -2019,7 +2177,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return self.union(other) - def unionByName(self, other, allowMissingColumns=False): + def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": """ Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. @@ -2065,7 +2223,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sql_ctx) @since(1.3) - def intersect(self, other): + def intersect(self, other: "DataFrame") -> "DataFrame": """ Return a new :class:`DataFrame` containing rows only in both this :class:`DataFrame` and another :class:`DataFrame`. 
@@ -2073,7 +2231,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) - def intersectAll(self, other): + def intersectAll(self, other: "DataFrame") -> "DataFrame": """ Return a new :class:`DataFrame` containing rows in both this :class:`DataFrame` and another :class:`DataFrame` while preserving duplicates. @@ -2100,7 +2258,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx) @since(1.3) - def subtract(self, other): + def subtract(self, other: "DataFrame") -> "DataFrame": """ Return a new :class:`DataFrame` containing rows in this :class:`DataFrame` but not in another :class:`DataFrame`. @@ -2109,7 +2267,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) - def dropDuplicates(self, subset=None): + def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -2155,7 +2313,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.dropDuplicates(self._jseq(subset)) return DataFrame(jdf, self.sql_ctx) - def dropna(self, how='any', thresh=None, subset=None): + def dropna( + self, + how: str = 'any', + thresh: Optional[int] = None, + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None + ) -> "DataFrame": """Returns a new :class:`DataFrame` omitting rows with null values. :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are aliases of each other. 
@@ -2198,7 +2361,23 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx) - def fillna(self, value, subset=None): + @overload + def fillna( + self, + value: "LiteralType", + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ..., + ) -> "DataFrame": + ... + + @overload + def fillna(self, value: Dict[str, "LiteralType"]) -> "DataFrame": + ... + + def fillna( + self, + value: Union["LiteralType", Dict[str, "LiteralType"]], + subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None + ) -> "DataFrame": """Replace null values, alias for ``na.fill()``. :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are aliases of each other. @@ -2269,7 +2448,49 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx) - def replace(self, to_replace, value=_NoValue, subset=None): + @overload + def replace( + self, + to_replace: "LiteralType", + value: "OptionalPrimitiveType", + subset: Optional[List[str]] = ..., + ) -> "DataFrame": + ... + + @overload + def replace( + self, + to_replace: List["LiteralType"], + value: List["OptionalPrimitiveType"], + subset: Optional[List[str]] = ..., + ) -> "DataFrame": + ... + + @overload + def replace( + self, + to_replace: Dict["LiteralType", "OptionalPrimitiveType"], + subset: Optional[List[str]] = ..., + ) -> "DataFrame": + ... + + @overload + def replace( + self, + to_replace: List["LiteralType"], + value: "OptionalPrimitiveType", + subset: Optional[List[str]] = ..., + ) -> "DataFrame": + ... 
+ + def replace( # type: ignore[misc] + self, + to_replace: Union[ + "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"] + ], + value: Optional[Union["OptionalPrimitiveType", List["OptionalPrimitiveType"]]] = _NoValue, + subset: Optional[List[str]] = None + ) -> "DataFrame": """Returns a new :class:`DataFrame` replacing a value with another value. :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are aliases of each other. @@ -2348,7 +2569,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("value argument is required when to_replace is not a dictionary.") # Helper functions - def all_of(types): + def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]: """Given a type or tuple of types and a sequence of xs check if each x is instance of type(s) @@ -2357,7 +2578,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): >>> all_of(str)(["a", 1]) False """ - def all_of_(xs): + def all_of_(xs: Iterable) -> bool: return all(isinstance(x, types) for x in xs) return all_of_ @@ -2415,7 +2636,30 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame( self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) - def approxQuantile(self, col, probabilities, relativeError): + @overload + def approxQuantile( + self, + col: str, + probabilities: Union[List[float], Tuple[float]], + relativeError: float, + ) -> List[float]: + ... + + @overload + def approxQuantile( + self, + col: Union[List[str], Tuple[str]], + probabilities: Union[List[float], Tuple[float]], + relativeError: float, + ) -> List[List[float]]: + ... + + def approxQuantile( + self, + col: Union[str, List[str], Tuple[str]], + probabilities: Union[List[float], Tuple[float]], + relativeError: float + ) -> Union[List[float], List[List[float]]]: """ Calculates the approximate quantiles of numerical columns of a :class:`DataFrame`. 
@@ -2474,7 +2718,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): if isinstance(col, tuple): col = list(col) elif isStr: - col = [col] + col = [cast(str, col)] for c in col: if not isinstance(c, str): @@ -2500,7 +2744,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jaq_list = [list(j) for j in jaq] return jaq_list[0] if isStr else jaq_list - def corr(self, col1, col2, method=None): + def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float: """ Calculates the correlation of two columns of a :class:`DataFrame` as a double value. Currently only supports the Pearson Correlation Coefficient. @@ -2528,7 +2772,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): "coefficient is supported.") return self._jdf.stat().corr(col1, col2, method) - def cov(self, col1, col2): + def cov(self, col1: str, col2: str) -> float: """ Calculate the sample covariance for the given columns, specified by their names, as a double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases. @@ -2548,7 +2792,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) - def crosstab(self, col1, col2): + def crosstab(self, col1: str, col2: str) -> "DataFrame": """ Computes a pair-wise frequency table of the given columns. Also known as a contingency table. The number of distinct values for each column should be less than 1e4. At most 1e6 @@ -2575,7 +2819,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) - def freqItems(self, cols, support=None): + def freqItems( + self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None + ) -> "DataFrame": """ Finding frequent items for columns, possibly with false positives. 
Using the frequent element count algorithm described in @@ -2607,7 +2853,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): support = 0.01 return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), support), self.sql_ctx) - def withColumn(self, colName, col): + def withColumn(self, colName: str, col: Column) -> "DataFrame": """ Returns a new :class:`DataFrame` by adding a column or replacing the existing column that has the same name. @@ -2641,7 +2887,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): raise TypeError("col should be Column") return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx) - def withColumnRenamed(self, existing, new): + def withColumnRenamed(self, existing: str, new: str) -> "DataFrame": """Returns a new :class:`DataFrame` by renaming an existing column. This is a no-op if schema doesn't contain the given column name. @@ -2661,7 +2907,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx) - def withMetadata(self, columnName, metadata): + def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame": """Returns a new :class:`DataFrame` by updating an existing column with metadata. .. versionadded:: 3.3.0 @@ -2681,12 +2927,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ if not isinstance(metadata, dict): raise TypeError("metadata should be a dict") - sc = SparkContext._active_spark_context + sc = SparkContext._active_spark_context # type: ignore[attr-defined] jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson( json.dumps(metadata)) return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sql_ctx) - def drop(self, *cols): + @overload + def drop(self, cols: "ColumnOrName") -> "DataFrame": + ... + + @overload + def drop(self, *cols: str) -> "DataFrame": + ... 
+ + def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] """Returns a new :class:`DataFrame` that drops the specified column. This is a no-op if schema doesn't contain the given column name(s). @@ -2730,7 +2984,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return DataFrame(jdf, self.sql_ctx) - def toDF(self, *cols): + def toDF(self, *cols: "ColumnOrName") -> "DataFrame": """Returns a new :class:`DataFrame` that with new specified column names Parameters @@ -2746,7 +3000,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): jdf = self._jdf.toDF(self._jseq(cols)) return DataFrame(jdf, self.sql_ctx) - def transform(self, func): + def transform(self, func: Callable[["DataFrame"], "DataFrame"]) -> "DataFrame": """Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations. .. versionadded:: 3.0.0 @@ -2777,7 +3031,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): "should have been DataFrame." % type(result) return result - def sameSemantics(self, other): + def sameSemantics(self, other: "DataFrame") -> bool: """ Returns `True` when the logical query plans inside both :class:`DataFrame`\\s are equal and therefore return same results. @@ -2811,7 +3065,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): % type(other)) return self._jdf.sameSemantics(other._jdf) - def semanticHash(self): + def semanticHash(self) -> int: """ Returns a hash code of the logical query plan against this :class:`DataFrame`. @@ -2833,7 +3087,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return self._jdf.semanticHash() - def inputFiles(self): + def inputFiles(self) -> List[str]: """ Returns a best-effort snapshot of the files that compose this :class:`DataFrame`. 
This method simply asks each constituent BaseRelation for its respective files and @@ -2870,7 +3124,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): sinceversion=1.4, doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.") - def writeTo(self, table): + def writeTo(self, table: str) -> DataFrameWriterV2: """ Create a write configuration builder for v2 sources. @@ -2889,7 +3143,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): """ return DataFrameWriterV2(self, table) - def to_pandas_on_spark(self, index_col=None): + def to_pandas_on_spark( + self, index_col: Optional[Union[str, List[str]]] = None + ) -> "PandasOnSparkDataFrame": """ Converts the existing DataFrame into a pandas-on-Spark DataFrame. @@ -2935,17 +3191,20 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): c 3 """ from pyspark.pandas.namespace import _get_index_map - from pyspark.pandas.frame import DataFrame + from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame from pyspark.pandas.internal import InternalFrame index_spark_columns, index_names = _get_index_map(self, index_col) internal = InternalFrame( - spark_frame=self, index_spark_columns=index_spark_columns, index_names=index_names + spark_frame=self, index_spark_columns=index_spark_columns, + index_names=index_names # type: ignore[arg-type] ) - return DataFrame(internal) + return PandasOnSparkDataFrame(internal) # Keep to_koalas for backward compatibility for now. - def to_koalas(self, index_col=None): + def to_koalas( + self, index_col: Optional[Union[str, List[str]]] = None + ) -> "PandasOnSparkDataFrame": warnings.warn( "DataFrame.to_koalas is deprecated. Use DataFrame.to_pandas_on_spark instead.", FutureWarning, @@ -2953,11 +3212,11 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): return self.to_pandas_on_spark(index_col) -def _to_scala_map(sc, jm): +def _to_scala_map(sc: SparkContext, jm: Dict) -> JavaObject: """ Convert a dict into a JVM Map. 
""" - return sc._jvm.PythonUtils.toScalaMap(jm) + return sc._jvm.PythonUtils.toScalaMap(jm) # type: ignore[attr-defined] class DataFrameNaFunctions(object): @@ -2966,21 +3225,70 @@ class DataFrameNaFunctions(object): .. versionadded:: 1.4 """ - def __init__(self, df): + def __init__(self, df: DataFrame): self.df = df - def drop(self, how='any', thresh=None, subset=None): + def drop( + self, how: str = 'any', thresh: Optional[int] = None, subset: Optional[List[str]] = None + ) -> DataFrame: return self.df.dropna(how=how, thresh=thresh, subset=subset) drop.__doc__ = DataFrame.dropna.__doc__ - def fill(self, value, subset=None): - return self.df.fillna(value=value, subset=subset) + @overload + def fill( + self, value: "LiteralType", subset: Optional[List[str]] = ... + ) -> DataFrame: + ... + + @overload + def fill(self, value: Dict[str, "LiteralType"]) -> DataFrame: + ... + + def fill( + self, + value: Union["LiteralType", Dict[str, "LiteralType"]], + subset: Optional[List[str]] = None + ) -> DataFrame: + return self.df.fillna(value=value, subset=subset) # type: ignore[arg-type] fill.__doc__ = DataFrame.fillna.__doc__ - def replace(self, to_replace, value=_NoValue, subset=None): - return self.df.replace(to_replace, value, subset) + @overload + def replace( + self, + to_replace: List["LiteralType"], + value: List["OptionalPrimitiveType"], + subset: Optional[List[str]] = ..., + ) -> DataFrame: + ... + + @overload + def replace( + self, + to_replace: Dict["LiteralType", "OptionalPrimitiveType"], + subset: Optional[List[str]] = ..., + ) -> DataFrame: + ... + + @overload + def replace( + self, + to_replace: List["LiteralType"], + value: "OptionalPrimitiveType", + subset: Optional[List[str]] = ..., + ) -> DataFrame: + ... 
+ + def replace( # type: ignore[misc] + self, + to_replace: Union[ + List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"] + ], + value: Optional[Union["OptionalPrimitiveType", List["OptionalPrimitiveType"]]] = _NoValue, + subset: Optional[List[str]] = None + ) -> DataFrame: + return self.df.replace(to_replace, value, subset) # type: ignore[arg-type] replace.__doc__ = DataFrame.replace.__doc__ @@ -2991,41 +3299,66 @@ class DataFrameStatFunctions(object): .. versionadded:: 1.4 """ - def __init__(self, df): + def __init__(self, df: DataFrame): self.df = df - def approxQuantile(self, col, probabilities, relativeError): + @overload + def approxQuantile( + self, + col: str, + probabilities: Union[List[float], Tuple[float]], + relativeError: float, + ) -> List[float]: + ... + + @overload + def approxQuantile( + self, + col: Union[List[str], Tuple[str]], + probabilities: Union[List[float], Tuple[float]], + relativeError: float, + ) -> List[List[float]]: + ... + + def approxQuantile( # type: ignore[misc] + self, + col: Union[str, List[str], Tuple[str]], + probabilities: Union[List[float], Tuple[float]], + relativeError: float + ) -> Union[List[float], List[List[float]]]: return self.df.approxQuantile(col, probabilities, relativeError) approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__ - def corr(self, col1, col2, method=None): + def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float: return self.df.corr(col1, col2, method) corr.__doc__ = DataFrame.corr.__doc__ - def cov(self, col1, col2): + def cov(self, col1: str, col2: str) -> float: return self.df.cov(col1, col2) cov.__doc__ = DataFrame.cov.__doc__ - def crosstab(self, col1, col2): + def crosstab(self, col1: str, col2: str) -> DataFrame: return self.df.crosstab(col1, col2) crosstab.__doc__ = DataFrame.crosstab.__doc__ - def freqItems(self, cols, support=None): + def freqItems(self, cols: List[str], support: Optional[float] = None) -> DataFrame: return self.df.freqItems(cols, 
support) freqItems.__doc__ = DataFrame.freqItems.__doc__ - def sampleBy(self, col, fractions, seed=None): + def sampleBy( + self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None + ) -> DataFrame: return self.df.sampleBy(col, fractions, seed) sampleBy.__doc__ = DataFrame.sampleBy.__doc__ -def _test(): +def _test() -> None: import doctest from pyspark.context import SparkContext from pyspark.sql import Row, SQLContext, SparkSession diff --git a/python/pyspark/sql/dataframe.pyi b/python/pyspark/sql/dataframe.pyi deleted file mode 100644 index d903a79..0000000 --- a/python/pyspark/sql/dataframe.pyi +++ /dev/null @@ -1,351 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -from typing import overload -from typing import ( - Any, - Callable, - Dict, - Iterator, - List, - Optional, - Tuple, - Union, -) - -from py4j.java_gateway import JavaObject # type: ignore[import] - -from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType -from pyspark._typing import PrimitiveType -from pyspark.sql.types import ( # noqa: F401 - StructType, - StructField, - StringType, - IntegerType, - Row, -) # noqa: F401 -from pyspark.sql.context import SQLContext -from pyspark.sql.group import GroupedData -from pyspark.sql.observation import Observation -from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 -from pyspark.sql.streaming import DataStreamWriter -from pyspark.sql.column import Column -from pyspark.rdd import RDD -from pyspark.storagelevel import StorageLevel - -from pyspark.sql.pandas.conversion import PandasConversionMixin -from pyspark.sql.pandas.map_ops import PandasMapOpsMixin -from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame - -class DataFrame(PandasMapOpsMixin, PandasConversionMixin): - sql_ctx: SQLContext - is_cached: bool - def __init__(self, jdf: JavaObject, sql_ctx: SQLContext) -> None: ... - @property - def rdd(self) -> RDD[Row]: ... - @property - def na(self) -> DataFrameNaFunctions: ... - @property - def stat(self) -> DataFrameStatFunctions: ... - def toJSON(self, use_unicode: bool = ...) -> RDD[str]: ... - def registerTempTable(self, name: str) -> None: ... - def createTempView(self, name: str) -> None: ... - def createOrReplaceTempView(self, name: str) -> None: ... - def createGlobalTempView(self, name: str) -> None: ... - @property - def write(self) -> DataFrameWriter: ... - @property - def writeStream(self) -> DataStreamWriter: ... - @property - def schema(self) -> StructType: ... - def printSchema(self) -> None: ... - def explain( - self, extended: Optional[Union[bool, str]] = ..., mode: Optional[str] = ... - ) -> None: ... 
- def exceptAll(self, other: DataFrame) -> DataFrame: ... - def isLocal(self) -> bool: ... - @property - def isStreaming(self) -> bool: ... - def show( - self, n: int = ..., truncate: Union[bool, int] = ..., vertical: bool = ... - ) -> None: ... - def checkpoint(self, eager: bool = ...) -> DataFrame: ... - def localCheckpoint(self, eager: bool = ...) -> DataFrame: ... - def withWatermark( - self, eventTime: str, delayThreshold: str - ) -> DataFrame: ... - def hint(self, name: str, *parameters: Union[PrimitiveType, List[PrimitiveType]]) -> DataFrame: ... - def count(self) -> int: ... - def collect(self) -> List[Row]: ... - def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[Row]: ... - def limit(self, num: int) -> DataFrame: ... - def take(self, num: int) -> List[Row]: ... - def tail(self, num: int) -> List[Row]: ... - def foreach(self, f: Callable[[Row], None]) -> None: ... - def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: ... - def cache(self) -> DataFrame: ... - def persist(self, storageLevel: StorageLevel = ...) -> DataFrame: ... - @property - def storageLevel(self) -> StorageLevel: ... - def unpersist(self, blocking: bool = ...) -> DataFrame: ... - def coalesce(self, numPartitions: int) -> DataFrame: ... - @overload - def repartition(self, numPartitions: int, *cols: ColumnOrName) -> DataFrame: ... - @overload - def repartition(self, *cols: ColumnOrName) -> DataFrame: ... - @overload - def repartitionByRange( - self, numPartitions: int, *cols: ColumnOrName - ) -> DataFrame: ... - @overload - def repartitionByRange(self, *cols: ColumnOrName) -> DataFrame: ... - def distinct(self) -> DataFrame: ... - @overload - def sample(self, fraction: float, seed: Optional[int] = ...) -> DataFrame: ... - @overload - def sample( - self, - withReplacement: Optional[bool], - fraction: float, - seed: Optional[int] = ..., - ) -> DataFrame: ... - def sampleBy( - self, col: ColumnOrName, fractions: Dict[Any, float], seed: Optional[int] = ... 
- ) -> DataFrame: ... - def randomSplit( - self, weights: List[float], seed: Optional[int] = ... - ) -> List[DataFrame]: ... - @property - def dtypes(self) -> List[Tuple[str, str]]: ... - @property - def columns(self) -> List[str]: ... - def colRegex(self, colName: str) -> Column: ... - def alias(self, alias: str) -> DataFrame: ... - def crossJoin(self, other: DataFrame) -> DataFrame: ... - def join( - self, - other: DataFrame, - on: Optional[Union[str, List[str], Column, List[Column]]] = ..., - how: Optional[str] = ..., - ) -> DataFrame: ... - def sortWithinPartitions( - self, - *cols: Union[str, Column, List[Union[str, Column]]], - ascending: Union[bool, List[bool]] = ... - ) -> DataFrame: ... - def sort( - self, - *cols: Union[str, Column, List[Union[str, Column]]], - ascending: Union[bool, List[bool]] = ... - ) -> DataFrame: ... - def orderBy( - self, - *cols: Union[str, Column, List[Union[str, Column]]], - ascending: Union[bool, List[bool]] = ... - ) -> DataFrame: ... - def describe(self, *cols: Union[str, List[str]]) -> DataFrame: ... - def summary(self, *statistics: str) -> DataFrame: ... - @overload - def head(self) -> Row: ... - @overload - def head(self, n: int) -> List[Row]: ... - def first(self) -> Row: ... - def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Column: ... - def __getattr__(self, name: str) -> Column: ... - @overload - def select(self, *cols: ColumnOrName) -> DataFrame: ... - @overload - def select(self, __cols: Union[List[Column], List[str]]) -> DataFrame: ... - @overload - def selectExpr(self, *expr: str) -> DataFrame: ... - @overload - def selectExpr(self, *expr: List[str]) -> DataFrame: ... - def filter(self, condition: ColumnOrName) -> DataFrame: ... - @overload - def groupBy(self, *cols: ColumnOrName) -> GroupedData: ... - @overload - def groupBy(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... - @overload - def rollup(self, *cols: ColumnOrName) -> GroupedData: ... 
- @overload - def rollup(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... - @overload - def cube(self, *cols: ColumnOrName) -> GroupedData: ... - @overload - def cube(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... - def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: ... - def observe(self, observation: Observation, *exprs: Column) -> DataFrame: ... - def union(self, other: DataFrame) -> DataFrame: ... - def unionAll(self, other: DataFrame) -> DataFrame: ... - def unionByName( - self, other: DataFrame, allowMissingColumns: bool = ... - ) -> DataFrame: ... - def intersect(self, other: DataFrame) -> DataFrame: ... - def intersectAll(self, other: DataFrame) -> DataFrame: ... - def subtract(self, other: DataFrame) -> DataFrame: ... - def dropDuplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ... - def dropna( - self, - how: str = ..., - thresh: Optional[int] = ..., - subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ..., - ) -> DataFrame: ... - @overload - def fillna( - self, - value: LiteralType, - subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ..., - ) -> DataFrame: ... - @overload - def fillna(self, value: Dict[str, LiteralType]) -> DataFrame: ... - @overload - def replace( - self, - to_replace: LiteralType, - value: OptionalPrimitiveType, - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: List[LiteralType], - value: List[OptionalPrimitiveType], - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: Dict[LiteralType, OptionalPrimitiveType], - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: List[LiteralType], - value: OptionalPrimitiveType, - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... 
- @overload - def approxQuantile( - self, - col: str, - probabilities: Union[List[float], Tuple[float]], - relativeError: float, - ) -> List[float]: ... - @overload - def approxQuantile( - self, - col: Union[List[str], Tuple[str]], - probabilities: Union[List[float], Tuple[float]], - relativeError: float, - ) -> List[List[float]]: ... - def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ... - def cov(self, col1: str, col2: str) -> float: ... - def crosstab(self, col1: str, col2: str) -> DataFrame: ... - def freqItems( - self, cols: Union[List[str], Tuple[str]], support: Optional[float] = ... - ) -> DataFrame: ... - def withColumn(self, colName: str, col: Column) -> DataFrame: ... - def withColumnRenamed(self, existing: str, new: str) -> DataFrame: ... - @overload - def drop(self, cols: ColumnOrName) -> DataFrame: ... - @overload - def drop(self, *cols: str) -> DataFrame: ... - def toDF(self, *cols: ColumnOrName) -> DataFrame: ... - def transform(self, func: Callable[[DataFrame], DataFrame]) -> DataFrame: ... - @overload - def groupby(self, *cols: ColumnOrName) -> GroupedData: ... - @overload - def groupby(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ... - def drop_duplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ... - def where(self, condition: ColumnOrName) -> DataFrame: ... - def sameSemantics(self, other: DataFrame) -> bool: ... - def semanticHash(self) -> int: ... - def inputFiles(self) -> List[str]: ... - def writeTo(self, table: str) -> DataFrameWriterV2: ... - def to_pandas_on_spark(self, index_col: Optional[Union[str, List[str]]] = None) -> PandasOnSparkDataFrame: ... - -class DataFrameNaFunctions: - df: DataFrame - def __init__(self, df: DataFrame) -> None: ... - def drop( - self, - how: str = ..., - thresh: Optional[int] = ..., - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def fill( - self, value: LiteralType, subset: Optional[List[str]] = ... - ) -> DataFrame: ... 
- @overload - def fill(self, value: Dict[str, LiteralType]) -> DataFrame: ... - @overload - def replace( - self, - to_replace: LiteralType, - value: OptionalPrimitiveType, - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: List[LiteralType], - value: List[OptionalPrimitiveType], - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: Dict[LiteralType, OptionalPrimitiveType], - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - @overload - def replace( - self, - to_replace: List[LiteralType], - value: OptionalPrimitiveType, - subset: Optional[List[str]] = ..., - ) -> DataFrame: ... - -class DataFrameStatFunctions: - df: DataFrame - def __init__(self, df: DataFrame) -> None: ... - @overload - def approxQuantile( - self, - col: str, - probabilities: Union[List[float], Tuple[float]], - relativeError: float, - ) -> List[float]: ... - @overload - def approxQuantile( - self, - col: Union[List[str], Tuple[str]], - probabilities: Union[List[float], Tuple[float]], - relativeError: float, - ) -> List[List[float]]: ... - def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ... - def cov(self, col1: str, col2: str) -> float: ... - def crosstab(self, col1: str, col2: str) -> DataFrame: ... - def freqItems( - self, cols: List[str], support: Optional[float] = ... - ) -> DataFrame: ... - def sampleBy( - self, col: str, fractions: Dict[Any, float], seed: Optional[int] = ... - ) -> DataFrame: ... 
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py index 3e8a0d1..48d8176 100644 --- a/python/pyspark/sql/observation.py +++ b/python/pyspark/sql/observation.py @@ -99,7 +99,7 @@ class Observation: assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" assert self._jo is None, "an Observation can be used with a DataFrame only once" - self._jvm = df._sc._jvm # type: ignore[assignment] + self._jvm = df._sc._jvm # type: ignore[assignment, attr-defined] cls = self._jvm.org.apache.spark.sql.Observation # type: ignore[attr-defined] self._jo = cls(self._name) if self._name is not None else cls() observed_df = self._jo.on( # type: ignore[attr-defined] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org