This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9995e62  [SPARK-36885][PYTHON] Inline type hints for 
pyspark.sql.dataframe
9995e62 is described below

commit 9995e623f7f65fdd3b1dc3cd4e0140a7cf4bc4a0
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Oct 12 09:17:14 2021 +0900

    [SPARK-36885][PYTHON] Inline type hints for pyspark.sql.dataframe
    
    ### What changes were proposed in this pull request?
    
    Inline type hints from `python/pyspark/sql/dataframe.pyi` to 
`python/pyspark/sql/dataframe.py`.
    
    ### Why are the changes needed?
    
    Currently, there is a type hint stub file `python/pyspark/sql/dataframe.pyi` 
to show the expected types for functions, but we can also take advantage of 
static type checking within the functions by inlining the type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #34225 from ueshin/issues/SPARK-36885/inline_typehints_dataframe.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/__init__.pyi       |   2 +-
 python/pyspark/sql/dataframe.py   | 671 ++++++++++++++++++++++++++++----------
 python/pyspark/sql/dataframe.pyi  | 351 --------------------
 python/pyspark/sql/observation.py |   2 +-
 4 files changed, 504 insertions(+), 522 deletions(-)

diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi
index f85319b..35df545 100644
--- a/python/pyspark/__init__.pyi
+++ b/python/pyspark/__init__.pyi
@@ -71,7 +71,7 @@ def since(version: Union[str, float]) -> Callable[[T], T]: ...
 def copy_func(
     f: F,
     name: Optional[str] = ...,
-    sinceversion: Optional[str] = ...,
+    sinceversion: Optional[Union[str, float]] = ...,
     doc: Optional[str] = ...,
 ) -> F: ...
 def keyword_only(func: F) -> F: ...
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8d4c94f..339f8f8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -22,22 +22,41 @@ import warnings
 from collections.abc import Iterable
 from functools import reduce
 from html import escape as html_escape
+from typing import (
+    Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Type, 
Union,
+    cast, overload, TYPE_CHECKING
+)
 
-from pyspark import copy_func, since, _NoValue
+from py4j.java_gateway import JavaObject  # type: ignore[import]
+
+from pyspark import copy_func, since, _NoValue  # type: ignore[attr-defined]
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket
+from pyspark.rdd import (  # type: ignore[attr-defined]
+    RDD, _load_from_socket, _local_iterator_from_socket
+)
 from pyspark.serializers import BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
-from pyspark.sql.types import _parse_datatype_json_string
-from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
+from pyspark.sql.types import _parse_datatype_json_string  # type: 
ignore[attr-defined]
+from pyspark.sql.column import (  # type: ignore[attr-defined]
+    Column, _to_seq, _to_list, _to_java_column
+)
 from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType
+from pyspark.sql.types import StructType, StructField, StringType, 
IntegerType, Row
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
+if TYPE_CHECKING:
+    from pyspark._typing import PrimitiveType
+    from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+    from pyspark.sql._typing import ColumnOrName, LiteralType, 
OptionalPrimitiveType
+    from pyspark.sql.context import SQLContext
+    from pyspark.sql.group import GroupedData
+    from pyspark.sql.observation import Observation
+
+
 __all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
 
 
@@ -68,42 +87,49 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
     .. versionadded:: 1.3.0
     """
 
-    def __init__(self, jdf, sql_ctx):
+    def __init__(self, jdf: JavaObject, sql_ctx: "SQLContext"):
         self._jdf = jdf
         self.sql_ctx = sql_ctx
-        self._sc = sql_ctx and sql_ctx._sc
+        self._sc = cast(
+            SparkContext,
+            sql_ctx and sql_ctx._sc  # type: ignore[attr-defined]
+        )
         self.is_cached = False
-        self._schema = None  # initialized lazily
-        self._lazy_rdd = None
+        # initialized lazily
+        self._schema: Optional[StructType] = None
+        self._lazy_rdd: Optional[RDD[Row]] = None
         # Check whether _repr_html is supported or not, we use it to avoid 
calling _jdf twice
         # by __repr__ and _repr_html_ while eager evaluation opened.
         self._support_repr_html = False
 
-    @property
+    @property  # type: ignore[misc]
     @since(1.3)
-    def rdd(self):
+    def rdd(self) -> "RDD[Row]":
         """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
         """
         if self._lazy_rdd is None:
             jrdd = self._jdf.javaToPython()
-            self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, 
BatchedSerializer(PickleSerializer()))
+            self._lazy_rdd = RDD(
+                jrdd, self.sql_ctx._sc,  # type: ignore[attr-defined]
+                BatchedSerializer(PickleSerializer())
+            )
         return self._lazy_rdd
 
-    @property
+    @property  # type: ignore[misc]
     @since("1.3.1")
-    def na(self):
+    def na(self) -> "DataFrameNaFunctions":
         """Returns a :class:`DataFrameNaFunctions` for handling missing values.
         """
         return DataFrameNaFunctions(self)
 
-    @property
+    @property  # type: ignore[misc]
     @since(1.4)
-    def stat(self):
+    def stat(self) -> "DataFrameStatFunctions":
         """Returns a :class:`DataFrameStatFunctions` for statistic functions.
         """
         return DataFrameStatFunctions(self)
 
-    def toJSON(self, use_unicode=True):
+    def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
         """Converts a :class:`DataFrame` into a :class:`RDD` of string.
 
         Each row is turned into a JSON document as one element in the returned 
RDD.
@@ -118,7 +144,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         rdd = self._jdf.toJSON()
         return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
 
-    def registerTempTable(self, name):
+    def registerTempTable(self, name: str) -> None:
         """Registers this :class:`DataFrame` as a temporary table using the 
given name.
 
         The lifetime of this temporary table is tied to the 
:class:`SparkSession`
@@ -145,7 +171,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         )
         self._jdf.createOrReplaceTempView(name)
 
-    def createTempView(self, name):
+    def createTempView(self, name: str) -> None:
         """Creates a local temporary view with this :class:`DataFrame`.
 
         The lifetime of this temporary table is tied to the 
:class:`SparkSession`
@@ -171,7 +197,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         self._jdf.createTempView(name)
 
-    def createOrReplaceTempView(self, name):
+    def createOrReplaceTempView(self, name: str) -> None:
         """Creates or replaces a local temporary view with this 
:class:`DataFrame`.
 
         The lifetime of this temporary table is tied to the 
:class:`SparkSession`
@@ -193,7 +219,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         self._jdf.createOrReplaceTempView(name)
 
-    def createGlobalTempView(self, name):
+    def createGlobalTempView(self, name: str) -> None:
         """Creates a global temporary view with this :class:`DataFrame`.
 
         The lifetime of this temporary view is tied to this Spark application.
@@ -218,7 +244,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         self._jdf.createGlobalTempView(name)
 
-    def createOrReplaceGlobalTempView(self, name):
+    def createOrReplaceGlobalTempView(self, name: str) -> None:
         """Creates or replaces a global temporary view using the given name.
 
         The lifetime of this temporary view is tied to this Spark application.
@@ -240,7 +266,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         self._jdf.createOrReplaceGlobalTempView(name)
 
     @property
-    def write(self):
+    def write(self) -> DataFrameWriter:
         """
         Interface for saving the content of the non-streaming 
:class:`DataFrame` out into external
         storage.
@@ -254,7 +280,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrameWriter(self)
 
     @property
-    def writeStream(self):
+    def writeStream(self) -> DataStreamWriter:
         """
         Interface for saving the content of the streaming :class:`DataFrame` 
out into external
         storage.
@@ -272,7 +298,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataStreamWriter(self)
 
     @property
-    def schema(self):
+    def schema(self) -> StructType:
         """Returns the schema of this :class:`DataFrame` as a 
:class:`pyspark.sql.types.StructType`.
 
         .. versionadded:: 1.3.0
@@ -290,7 +316,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                     "Unable to parse datatype from schema. %s" % e) from e
         return self._schema
 
-    def printSchema(self):
+    def printSchema(self) -> None:
         """Prints out the schema in the tree format.
 
         .. versionadded:: 1.3.0
@@ -305,7 +331,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         print(self._jdf.schema().treeString())
 
-    def explain(self, extended=None, mode=None):
+    def explain(
+        self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] 
= None
+    ) -> None:
         """Prints the (logical and physical) plans to the console for 
debugging purpose.
 
         .. versionadded:: 1.3.0
@@ -390,13 +418,16 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         elif is_extended_case:
             explain_mode = "extended" if extended else "simple"
         elif is_mode_case:
-            explain_mode = mode
+            explain_mode = cast(str, mode)
         elif is_extended_as_mode:
-            explain_mode = extended
+            explain_mode = cast(str, extended)
 
-        
print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), 
explain_mode))
+        print(
+            self._sc._jvm  # type: ignore[attr-defined]
+                .PythonSQLUtils.explainString(self._jdf.queryExecution(), 
explain_mode)
+        )
 
-    def exceptAll(self, other):
+    def exceptAll(self, other: "DataFrame") -> "DataFrame":
         """Return a new :class:`DataFrame` containing rows in this 
:class:`DataFrame` but
         not in another :class:`DataFrame` while preserving duplicates.
 
@@ -425,14 +456,14 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(self._jdf.exceptAll(other._jdf), self.sql_ctx)
 
     @since(1.3)
-    def isLocal(self):
+    def isLocal(self) -> bool:
         """Returns ``True`` if the :func:`collect` and :func:`take` methods 
can be run locally
         (without any Spark executors).
         """
         return self._jdf.isLocal()
 
     @property
-    def isStreaming(self):
+    def isStreaming(self) -> bool:
         """Returns ``True`` if this :class:`DataFrame` contains one or more 
sources that
         continuously return data as it arrives. A :class:`DataFrame` that 
reads data from a
         streaming source must be executed as a :class:`StreamingQuery` using 
the :func:`start`
@@ -448,7 +479,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return self._jdf.isStreaming()
 
-    def show(self, n=20, truncate=True, vertical=False):
+    def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: 
bool = False) -> None:
         """Prints the first ``n`` rows to the console.
 
         .. versionadded:: 1.3.0
@@ -509,26 +540,33 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
             print(self._jdf.showString(n, int_truncate, vertical))
 
-    def __repr__(self):
-        if not self._support_repr_html and 
self.sql_ctx._conf.isReplEagerEvalEnabled():
+    def __repr__(self) -> str:
+        if (
+            not self._support_repr_html
+            and self.sql_ctx._conf.isReplEagerEvalEnabled()  # type: 
ignore[attr-defined]
+        ):
             vertical = False
             return self._jdf.showString(
-                self.sql_ctx._conf.replEagerEvalMaxNumRows(),
-                self.sql_ctx._conf.replEagerEvalTruncate(), vertical)
+                self.sql_ctx._conf.replEagerEvalMaxNumRows(),  # type: 
ignore[attr-defined]
+                self.sql_ctx._conf.replEagerEvalTruncate(), vertical)  # type: 
ignore[attr-defined]
         else:
             return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in 
self.dtypes))
 
-    def _repr_html_(self):
+    def _repr_html_(self) -> Optional[str]:
         """Returns a :class:`DataFrame` with html code when you enabled eager 
evaluation
         by 'spark.sql.repl.eagerEval.enabled', this only called by REPL you are
         using support eager evaluation with HTML.
         """
         if not self._support_repr_html:
             self._support_repr_html = True
-        if self.sql_ctx._conf.isReplEagerEvalEnabled():
-            max_num_rows = max(self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0)
+        if self.sql_ctx._conf.isReplEagerEvalEnabled():  # type: 
ignore[attr-defined]
+            max_num_rows = max(
+                self.sql_ctx._conf.replEagerEvalMaxNumRows(), 0  # type: 
ignore[attr-defined]
+            )
             sock_info = self._jdf.getRowsToPython(
-                max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate())
+                max_num_rows,
+                self.sql_ctx._conf.replEagerEvalTruncate()  # type: 
ignore[attr-defined]
+            )
             rows = list(_load_from_socket(sock_info, 
BatchedSerializer(PickleSerializer())))
             head = rows[0]
             row_data = rows[1:]
@@ -550,7 +588,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         else:
             return None
 
-    def checkpoint(self, eager=True):
+    def checkpoint(self, eager: bool = True) -> "DataFrame":
         """Returns a checkpointed version of this :class:`DataFrame`. 
Checkpointing can be used to
         truncate the logical plan of this :class:`DataFrame`, which is 
especially useful in
         iterative algorithms where the plan may grow exponentially. It will be 
saved to files
@@ -570,7 +608,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.checkpoint(eager)
         return DataFrame(jdf, self.sql_ctx)
 
-    def localCheckpoint(self, eager=True):
+    def localCheckpoint(self, eager: bool = True) -> "DataFrame":
         """Returns a locally checkpointed version of this :class:`DataFrame`. 
Checkpointing can be
         used to truncate the logical plan of this :class:`DataFrame`, which is 
especially useful in
         iterative algorithms where the plan may grow exponentially. Local 
checkpoints are
@@ -590,7 +628,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.localCheckpoint(eager)
         return DataFrame(jdf, self.sql_ctx)
 
-    def withWatermark(self, eventTime, delayThreshold):
+    def withWatermark(self, eventTime: str, delayThreshold: str) -> 
"DataFrame":
         """Defines an event time watermark for this :class:`DataFrame`. A 
watermark tracks a point
         in time before which we assume no more late data is going to arrive.
 
@@ -634,7 +672,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.withWatermark(eventTime, delayThreshold)
         return DataFrame(jdf, self.sql_ctx)
 
-    def hint(self, name, *parameters):
+    def hint(
+        self, name: str, *parameters: Union["PrimitiveType", 
List["PrimitiveType"]]
+    ) -> "DataFrame":
         """Specifies some hint on the current :class:`DataFrame`.
 
         .. versionadded:: 2.2.0
@@ -660,7 +700,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         +----+---+------+
         """
         if len(parameters) == 1 and isinstance(parameters[0], list):
-            parameters = parameters[0]
+            parameters = parameters[0]  # type: ignore[assignment]
 
         if not isinstance(name, str):
             raise TypeError("name should be provided as str, got 
{0}".format(type(name)))
@@ -675,7 +715,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.hint(name, self._jseq(parameters))
         return DataFrame(jdf, self.sql_ctx)
 
-    def count(self):
+    def count(self) -> int:
         """Returns the number of rows in this :class:`DataFrame`.
 
         .. versionadded:: 1.3.0
@@ -687,7 +727,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return int(self._jdf.count())
 
-    def collect(self):
+    def collect(self) -> List[Row]:
         """Returns all the records as a list of :class:`Row`.
 
         .. versionadded:: 1.3.0
@@ -701,7 +741,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             sock_info = self._jdf.collectToPython()
         return list(_load_from_socket(sock_info, 
BatchedSerializer(PickleSerializer())))
 
-    def toLocalIterator(self, prefetchPartitions=False):
+    def toLocalIterator(self, prefetchPartitions: bool = False) -> 
Iterator[Row]:
         """
         Returns an iterator that contains all of the rows in this 
:class:`DataFrame`.
         The iterator will consume as much memory as the largest partition in 
this
@@ -724,7 +764,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             sock_info = self._jdf.toPythonIterator(prefetchPartitions)
         return _local_iterator_from_socket(sock_info, 
BatchedSerializer(PickleSerializer()))
 
-    def limit(self, num):
+    def limit(self, num: int) -> "DataFrame":
         """Limits the result count to the number specified.
 
         .. versionadded:: 1.3.0
@@ -739,7 +779,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.limit(num)
         return DataFrame(jdf, self.sql_ctx)
 
-    def take(self, num):
+    def take(self, num: int) -> List[Row]:
         """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
 
         .. versionadded:: 1.3.0
@@ -751,7 +791,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return self.limit(num).collect()
 
-    def tail(self, num):
+    def tail(self, num: int) -> List[Row]:
         """
         Returns the last ``num`` rows as a :class:`list` of :class:`Row`.
 
@@ -769,7 +809,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             sock_info = self._jdf.tailToPython(num)
         return list(_load_from_socket(sock_info, 
BatchedSerializer(PickleSerializer())))
 
-    def foreach(self, f):
+    def foreach(self, f: Callable[[Row], None]) -> None:
         """Applies the ``f`` function to all :class:`Row` of this 
:class:`DataFrame`.
 
         This is a shorthand for ``df.rdd.foreach()``.
@@ -784,7 +824,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         self.rdd.foreach(f)
 
-    def foreachPartition(self, f):
+    def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
         """Applies the ``f`` function to each partition of this 
:class:`DataFrame`.
 
         This a shorthand for ``df.rdd.foreachPartition()``.
@@ -798,9 +838,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         ...         print(person.name)
         >>> df.foreachPartition(f)
         """
-        self.rdd.foreachPartition(f)
+        self.rdd.foreachPartition(f)  # type: ignore[arg-type]
 
-    def cache(self):
+    def cache(self) -> "DataFrame":
         """Persists the :class:`DataFrame` with the default storage level 
(`MEMORY_AND_DISK`).
 
         .. versionadded:: 1.3.0
@@ -813,7 +853,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         self._jdf.cache()
         return self
 
-    def persist(self, storageLevel=StorageLevel.MEMORY_AND_DISK_DESER):
+    def persist(
+        self,
+        storageLevel: StorageLevel = (
+            StorageLevel.MEMORY_AND_DISK_DESER  # type: ignore[attr-defined]
+        )
+    ) -> "DataFrame":
         """Sets the storage level to persist the contents of the 
:class:`DataFrame` across
         operations after the first time it is computed. This can only be used 
to assign
         a new storage level if the :class:`DataFrame` does not have a storage 
level set yet.
@@ -826,12 +871,12 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         The default storage level has changed to `MEMORY_AND_DISK_DESER` to 
match Scala in 3.0.
         """
         self.is_cached = True
-        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)  # 
type: ignore[attr-defined]
         self._jdf.persist(javaStorageLevel)
         return self
 
     @property
-    def storageLevel(self):
+    def storageLevel(self) -> StorageLevel:
         """Get the :class:`DataFrame`'s current storage level.
 
         .. versionadded:: 2.1.0
@@ -853,7 +898,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                                      java_storage_level.replication())
         return storage_level
 
-    def unpersist(self, blocking=False):
+    def unpersist(self, blocking: bool = False) -> "DataFrame":
         """Marks the :class:`DataFrame` as non-persistent, and remove all 
blocks for it from
         memory and disk.
 
@@ -867,7 +912,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         self._jdf.unpersist(blocking)
         return self
 
-    def coalesce(self, numPartitions):
+    def coalesce(self, numPartitions: int) -> "DataFrame":
         """
         Returns a new :class:`DataFrame` that has exactly `numPartitions` 
partitions.
 
@@ -898,7 +943,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(self._jdf.coalesce(numPartitions), self.sql_ctx)
 
-    def repartition(self, numPartitions, *cols):
+    @overload
+    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> 
"DataFrame":
+        ...
+
+    @overload
+    def repartition(self, *cols: "ColumnOrName") -> "DataFrame":
+        ...
+
+    def repartition(  # type: ignore[misc]
+        self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+    ) -> "DataFrame":
         """
         Returns a new :class:`DataFrame` partitioned by the given partitioning 
expressions. The
         resulting :class:`DataFrame` is hash partitioned.
@@ -967,7 +1022,17 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         else:
             raise TypeError("numPartitions should be an int or Column")
 
-    def repartitionByRange(self, numPartitions, *cols):
+    @overload
+    def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> 
"DataFrame":
+        ...
+
+    @overload
+    def repartitionByRange(self, *cols: "ColumnOrName") -> "DataFrame":
+        ...
+
+    def repartitionByRange(  # type: ignore[misc]
+        self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
+    ) -> "DataFrame":
         """
         Returns a new :class:`DataFrame` partitioned by the given partitioning 
expressions. The
         resulting :class:`DataFrame` is range partitioned.
@@ -1017,7 +1082,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         if isinstance(numPartitions, int):
             if len(cols) == 0:
-                return ValueError("At least one partition-by expression must 
be specified.")
+                raise ValueError("At least one partition-by expression must be 
specified.")
             else:
                 return DataFrame(
                     self._jdf.repartitionByRange(numPartitions, 
self._jcols(*cols)), self.sql_ctx)
@@ -1027,7 +1092,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         else:
             raise TypeError("numPartitions should be an int, string or Column")
 
-    def distinct(self):
+    def distinct(self) -> "DataFrame":
         """Returns a new :class:`DataFrame` containing the distinct rows in 
this :class:`DataFrame`.
 
         .. versionadded:: 1.3.0
@@ -1039,7 +1104,25 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         """
         return DataFrame(self._jdf.distinct(), self.sql_ctx)
 
-    def sample(self, withReplacement=None, fraction=None, seed=None):
+    @overload
+    def sample(self, fraction: float, seed: Optional[int] = ...) -> 
"DataFrame":
+        ...
+
+    @overload
+    def sample(
+        self,
+        withReplacement: Optional[bool],
+        fraction: float,
+        seed: Optional[int] = ...,
+    ) -> "DataFrame":
+        ...
+
+    def sample(  # type: ignore[misc]
+        self,
+        withReplacement: Optional[Union[float, bool]] = None,
+        fraction: Optional[Union[int, float]] = None,
+        seed: Optional[int] = None
+    ) -> "DataFrame":
         """Returns a sampled subset of this :class:`DataFrame`.
 
         .. versionadded:: 1.3.0
@@ -1105,7 +1188,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         if is_withReplacement_omitted_args:
             if fraction is not None:
-                seed = fraction
+                seed = cast(int, fraction)
             fraction = withReplacement
             withReplacement = None
 
@@ -1114,7 +1197,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.sample(*args)
         return DataFrame(jdf, self.sql_ctx)
 
-    def sampleBy(self, col, fractions, seed=None):
+    def sampleBy(
+        self, col: "ColumnOrName", fractions: Dict[Any, float], seed: 
Optional[int] = None
+    ) -> "DataFrame":
         """
         Returns a stratified sample without replacement based on the
         fraction given on each stratum.
@@ -1167,7 +1252,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
         return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), 
seed), self.sql_ctx)
 
-    def randomSplit(self, weights, seed=None):
+    def randomSplit(self, weights: List[float], seed: Optional[int] = None) -> 
List["DataFrame"]:
         """Randomly splits this :class:`DataFrame` with the provided weights.
 
         .. versionadded:: 1.4.0
@@ -1193,11 +1278,13 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
             if w < 0.0:
                 raise ValueError("Weights must be positive. Found weight 
value: %s" % w)
         seed = seed if seed is not None else random.randint(0, sys.maxsize)
-        rdd_array = self._jdf.randomSplit(_to_list(self.sql_ctx._sc, weights), 
int(seed))
+        rdd_array = self._jdf.randomSplit(
+            _to_list(self.sql_ctx._sc, weights), int(seed)  # type: 
ignore[attr-defined]
+        )
         return [DataFrame(rdd, self.sql_ctx) for rdd in rdd_array]
 
     @property
-    def dtypes(self):
+    def dtypes(self) -> List[Tuple[str, str]]:
         """Returns all column names and their data types as a list.
 
         .. versionadded:: 1.3.0
@@ -1210,7 +1297,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return [(str(f.name), f.dataType.simpleString()) for f in 
self.schema.fields]
 
     @property
-    def columns(self):
+    def columns(self) -> List[str]:
         """Returns all column names as a list.
 
         .. versionadded:: 1.3.0
@@ -1222,7 +1309,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return [f.name for f in self.schema.fields]
 
-    def colRegex(self, colName):
+    def colRegex(self, colName: str) -> Column:
         """
         Selects column based on the column name specified as a regex and 
returns it
         as :class:`Column`.
@@ -1251,7 +1338,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jc = self._jdf.colRegex(colName)
         return Column(jc)
 
-    def alias(self, alias):
+    def alias(self, alias: str) -> "DataFrame":
         """Returns a new :class:`DataFrame` with an alias set.
 
         .. versionadded:: 1.3.0
@@ -1274,7 +1361,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         assert isinstance(alias, str), "alias should be a string"
         return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
 
-    def crossJoin(self, other):
+    def crossJoin(self, other: "DataFrame") -> "DataFrame":
         """Returns the cartesian product with another :class:`DataFrame`.
 
         .. versionadded:: 2.1.0
@@ -1298,7 +1385,12 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         jdf = self._jdf.crossJoin(other._jdf)
         return DataFrame(jdf, self.sql_ctx)
 
-    def join(self, other, on=None, how=None):
+    def join(
+        self,
+        other: "DataFrame",
+        on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+        how: Optional[str] = None
+    ) -> "DataFrame":
         """Joins with another :class:`DataFrame`, using the given join 
expression.
 
         .. versionadded:: 1.3.0
@@ -1342,14 +1434,14 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         """
 
         if on is not None and not isinstance(on, list):
-            on = [on]
+            on = [on]  # type: ignore[assignment]
 
         if on is not None:
             if isinstance(on[0], str):
-                on = self._jseq(on)
+                on = self._jseq(cast(List[str], on))
             else:
                 assert isinstance(on[0], Column), "on should be Column or list 
of Column"
-                on = reduce(lambda x, y: x.__and__(y), on)
+                on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
                 on = on._jc
 
         if on is None and how is None:
@@ -1366,16 +1458,16 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
     # TODO(SPARK-22947): Fix the DataFrame API.
     def _joinAsOf(
         self,
-        other,
-        leftAsOfColumn,
-        rightAsOfColumn,
-        on=None,
-        how=None,
+        other: "DataFrame",
+        leftAsOfColumn: Union[str, Column],
+        rightAsOfColumn: Union[str, Column],
+        on: Optional[Union[str, List[str], Column, List[Column]]] = None,
+        how: Optional[str] = None,
         *,
-        tolerance=None,
-        allowExactMatches=True,
-        direction="backward",
-    ):
+        tolerance: Optional[Column] = None,
+        allowExactMatches: bool = True,
+        direction: str = "backward",
+    ) -> "DataFrame":
         """
         Perform an as-of join.
 
@@ -1448,20 +1540,20 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         """
         if isinstance(leftAsOfColumn, str):
             leftAsOfColumn = self[leftAsOfColumn]
-        left_as_of_jcol = leftAsOfColumn._jc
+        left_as_of_jcol = cast(Column, leftAsOfColumn)._jc
         if isinstance(rightAsOfColumn, str):
             rightAsOfColumn = other[rightAsOfColumn]
-        right_as_of_jcol = rightAsOfColumn._jc
+        right_as_of_jcol = cast(Column, rightAsOfColumn)._jc
 
         if on is not None and not isinstance(on, list):
-            on = [on]
+            on = [on]  # type: ignore[assignment]
 
         if on is not None:
             if isinstance(on[0], str):
-                on = self._jseq(on)
+                on = self._jseq(cast(List[str], on))
             else:
                 assert isinstance(on[0], Column), "on should be Column or list 
of Column"
-                on = reduce(lambda x, y: x.__and__(y), on)
+                on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
                 on = on._jc
 
         if how is None:
@@ -1480,7 +1572,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         )
         return DataFrame(jdf, self.sql_ctx)
 
-    def sortWithinPartitions(self, *cols, **kwargs):
+    def sortWithinPartitions(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         """Returns a new :class:`DataFrame` with each partition sorted by the 
specified column(s).
 
         .. versionadded:: 1.6.0
@@ -1510,7 +1604,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sql_ctx)
 
-    def sort(self, *cols, **kwargs):
+    def sort(
+        self, *cols: Union[str, Column, List[Union[str, Column]]], **kwargs: Any
+    ) -> "DataFrame":
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
 
         .. versionadded:: 1.3.0
@@ -1548,15 +1644,19 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
 
     orderBy = sort
 
-    def _jseq(self, cols, converter=None):
+    def _jseq(
+        self,
+        cols: Sequence,
+        converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None
+    ) -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or names"""
-        return _to_seq(self.sql_ctx._sc, cols, converter)
+        return _to_seq(self.sql_ctx._sc, cols, converter)  # type: ignore[attr-defined]
 
-    def _jmap(self, jm):
+    def _jmap(self, jm: Dict) -> JavaObject:
         """Return a JVM Scala Map from a dict"""
-        return _to_scala_map(self.sql_ctx._sc, jm)
+        return _to_scala_map(self.sql_ctx._sc, jm)  # type: ignore[attr-defined]
 
-    def _jcols(self, *cols):
+    def _jcols(self, *cols: "ColumnOrName") -> JavaObject:
         """Return a JVM Seq of Columns from a list of Column or column names
 
         If `cols` has only one list in it, cols[0] will be used as the list.
@@ -1565,7 +1665,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             cols = cols[0]
         return self._jseq(cols, _to_java_column)
 
-    def _sort_cols(self, cols, kwargs):
+    def _sort_cols(
+        self, cols: Sequence[Union[str, Column, List[Union[str, Column]]]], kwargs: Dict[str, Any]
+    ) -> JavaObject:
         """ Return a JVM Seq of Columns that describes the sort order
         """
         if not cols:
@@ -1584,7 +1686,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("ascending can only be boolean or list, but got 
%s" % type(ascending))
         return self._jseq(jcols)
 
-    def describe(self, *cols):
+    def describe(self, *cols: Union[str, List[str]]) -> "DataFrame":
         """Computes basic statistics for numeric and string columns.
 
         .. versionadded:: 1.3.1
@@ -1628,11 +1730,11 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         DataFrame.summary
         """
         if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
+            cols = cols[0]  # type: ignore[assignment]
         jdf = self._jdf.describe(self._jseq(cols))
         return DataFrame(jdf, self.sql_ctx)
 
-    def summary(self, *statistics):
+    def summary(self, *statistics: str) -> "DataFrame":
         """Computes specified statistics for numeric and string columns. 
Available statistics are:
         - count
         - mean
@@ -1697,7 +1799,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         jdf = self._jdf.summary(self._jseq(statistics))
         return DataFrame(jdf, self.sql_ctx)
 
-    def head(self, n=None):
+    @overload
+    def head(self) -> Optional[Row]:
+        ...
+
+    @overload
+    def head(self, n: int) -> List[Row]:
+        ...
+
+    def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
         """Returns the first ``n`` rows.
 
         .. versionadded:: 1.3.0
@@ -1729,7 +1839,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             return rs[0] if rs else None
         return self.take(n)
 
-    def first(self):
+    def first(self) -> Optional[Row]:
         """Returns the first row as a :class:`Row`.
 
         .. versionadded:: 1.3.0
@@ -1741,7 +1851,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         """
         return self.head()
 
-    def __getitem__(self, item):
+    @overload
+    def __getitem__(self, item: Union[int, str]) -> Column:
+        ...
+
+    @overload
+    def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":
+        ...
+
+    def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
         """Returns the column as a :class:`Column`.
 
         .. versionadded:: 1.3.0
@@ -1770,7 +1888,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         else:
             raise TypeError("unexpected item type: %s" % type(item))
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Column:
         """Returns the :class:`Column` denoted by ``name``.
 
         .. versionadded:: 1.3.0
@@ -1786,7 +1904,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         jc = self._jdf.apply(name)
         return Column(jc)
 
-    def select(self, *cols):
+    @overload
+    def select(self, *cols: "ColumnOrName") -> "DataFrame":
+        ...
+
+    @overload
+    def select(self, __cols: Union[List[Column], List[str]]) -> "DataFrame":
+        ...
+
+    def select(self, *cols: "ColumnOrName") -> "DataFrame":  # type: ignore[misc]
         """Projects a set of expressions and returns a new :class:`DataFrame`.
 
         .. versionadded:: 1.3.0
@@ -1810,7 +1936,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         jdf = self._jdf.select(self._jcols(*cols))
         return DataFrame(jdf, self.sql_ctx)
 
-    def selectExpr(self, *expr):
+    @overload
+    def selectExpr(self, *expr: str) -> "DataFrame":
+        ...
+
+    @overload
+    def selectExpr(self, *expr: List[str]) -> "DataFrame":
+        ...
+
+    def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame":
         """Projects a set of SQL expressions and returns a new 
:class:`DataFrame`.
 
         This is a variant of :func:`select` that accepts SQL expressions.
@@ -1823,11 +1957,11 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)]
         """
         if len(expr) == 1 and isinstance(expr[0], list):
-            expr = expr[0]
+            expr = expr[0]  # type: ignore[assignment]
         jdf = self._jdf.selectExpr(self._jseq(expr))
         return DataFrame(jdf, self.sql_ctx)
 
-    def filter(self, condition):
+    def filter(self, condition: "ColumnOrName") -> "DataFrame":
         """Filters rows using the given condition.
 
         :func:`where` is an alias for :func:`filter`.
@@ -1860,7 +1994,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
             raise TypeError("condition should be string or Column")
         return DataFrame(jdf, self.sql_ctx)
 
-    def groupBy(self, *cols):
+    @overload
+    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":
+        ...
+
+    @overload
+    def groupBy(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+        ...
+
+    def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
         """Groups the :class:`DataFrame` using the specified columns,
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
@@ -1890,7 +2032,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         from pyspark.sql.group import GroupedData
         return GroupedData(jgd, self)
 
-    def rollup(self, *cols):
+    @overload
+    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
+        ...
+
+    @overload
+    def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+        ...
+
+    def rollup(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
         """
         Create a multi-dimensional rollup for the current :class:`DataFrame` 
using
         the specified columns, so we can run aggregation on them.
@@ -1914,7 +2064,15 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         from pyspark.sql.group import GroupedData
         return GroupedData(jgd, self)
 
-    def cube(self, *cols):
+    @overload
+    def cube(self, *cols: "ColumnOrName") -> "GroupedData":
+        ...
+
+    @overload
+    def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
+        ...
+
+    def cube(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
         """
         Create a multi-dimensional cube for the current :class:`DataFrame` 
using
         the specified columns, so we can run aggregations on them.
@@ -1940,7 +2098,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         from pyspark.sql.group import GroupedData
         return GroupedData(jgd, self)
 
-    def agg(self, *exprs):
+    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy().agg()``).
 
@@ -1954,10 +2112,10 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         >>> df.agg(F.min(df.age)).collect()
         [Row(min(age)=2)]
         """
-        return self.groupBy().agg(*exprs)
+        return self.groupBy().agg(*exprs)  # type: ignore[arg-type]
 
     @since(3.3)
-    def observe(self, observation, *exprs):
+    def observe(self, observation: "Observation", *exprs: Column) -> "DataFrame":
         """Observe (named) metrics through an :class:`Observation` instance.
 
         A user can retrieve the metrics by accessing `Observation.get`.
@@ -1996,7 +2154,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return observation._on(self, *exprs)
 
     @since(2.0)
-    def union(self, other):
+    def union(self, other: "DataFrame") -> "DataFrame":
         """ Return a new :class:`DataFrame` containing union of rows in this 
and another
         :class:`DataFrame`.
 
@@ -2008,7 +2166,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(self._jdf.union(other._jdf), self.sql_ctx)
 
     @since(1.3)
-    def unionAll(self, other):
+    def unionAll(self, other: "DataFrame") -> "DataFrame":
         """ Return a new :class:`DataFrame` containing union of rows in this 
and another
         :class:`DataFrame`.
 
@@ -2019,7 +2177,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return self.union(other)
 
-    def unionByName(self, other, allowMissingColumns=False):
+    def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
         """ Returns a new :class:`DataFrame` containing union of rows in this 
and another
         :class:`DataFrame`.
 
@@ -2065,7 +2223,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(self._jdf.unionByName(other._jdf, 
allowMissingColumns), self.sql_ctx)
 
     @since(1.3)
-    def intersect(self, other):
+    def intersect(self, other: "DataFrame") -> "DataFrame":
         """ Return a new :class:`DataFrame` containing rows only in
         both this :class:`DataFrame` and another :class:`DataFrame`.
 
@@ -2073,7 +2231,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
 
-    def intersectAll(self, other):
+    def intersectAll(self, other: "DataFrame") -> "DataFrame":
         """ Return a new :class:`DataFrame` containing rows in both this 
:class:`DataFrame`
         and another :class:`DataFrame` while preserving duplicates.
 
@@ -2100,7 +2258,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         return DataFrame(self._jdf.intersectAll(other._jdf), self.sql_ctx)
 
     @since(1.3)
-    def subtract(self, other):
+    def subtract(self, other: "DataFrame") -> "DataFrame":
         """ Return a new :class:`DataFrame` containing rows in this 
:class:`DataFrame`
         but not in another :class:`DataFrame`.
 
@@ -2109,7 +2267,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), 
self.sql_ctx)
 
-    def dropDuplicates(self, subset=None):
+    def dropDuplicates(self, subset: Optional[List[str]] = None) -> "DataFrame":
         """Return a new :class:`DataFrame` with duplicate rows removed,
         optionally only considering certain columns.
 
@@ -2155,7 +2313,12 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
             jdf = self._jdf.dropDuplicates(self._jseq(subset))
         return DataFrame(jdf, self.sql_ctx)
 
-    def dropna(self, how='any', thresh=None, subset=None):
+    def dropna(
+        self,
+        how: str = 'any',
+        thresh: Optional[int] = None,
+        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None
+    ) -> "DataFrame":
         """Returns a new :class:`DataFrame` omitting rows with null values.
         :func:`DataFrame.dropna` and :func:`DataFrameNaFunctions.drop` are 
aliases of each other.
 
@@ -2198,7 +2361,23 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
 
         return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), 
self.sql_ctx)
 
-    def fillna(self, value, subset=None):
+    @overload
+    def fillna(
+        self,
+        value: "LiteralType",
+        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def fillna(self, value: Dict[str, "LiteralType"]) -> "DataFrame":
+        ...
+
+    def fillna(
+        self,
+        value: Union["LiteralType", Dict[str, "LiteralType"]],
+        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None
+    ) -> "DataFrame":
         """Replace null values, alias for ``na.fill()``.
         :func:`DataFrame.fillna` and :func:`DataFrameNaFunctions.fill` are 
aliases of each other.
 
@@ -2269,7 +2448,49 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
 
             return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), 
self.sql_ctx)
 
-    def replace(self, to_replace, value=_NoValue, subset=None):
+    @overload
+    def replace(
+        self,
+        to_replace: "LiteralType",
+        value: "OptionalPrimitiveType",
+        subset: Optional[List[str]] = ...,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def replace(
+        self,
+        to_replace: List["LiteralType"],
+        value: List["OptionalPrimitiveType"],
+        subset: Optional[List[str]] = ...,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def replace(
+        self,
+        to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+        subset: Optional[List[str]] = ...,
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def replace(
+        self,
+        to_replace: List["LiteralType"],
+        value: "OptionalPrimitiveType",
+        subset: Optional[List[str]] = ...,
+    ) -> "DataFrame":
+        ...
+
+    def replace(  # type: ignore[misc]
+        self,
+        to_replace: Union[
+            "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+        ],
+        value: Optional[Union["OptionalPrimitiveType", List["OptionalPrimitiveType"]]] = _NoValue,
+        subset: Optional[List[str]] = None
+    ) -> "DataFrame":
         """Returns a new :class:`DataFrame` replacing a value with another 
value.
         :func:`DataFrame.replace` and :func:`DataFrameNaFunctions.replace` are
         aliases of each other.
@@ -2348,7 +2569,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                 raise TypeError("value argument is required when to_replace is 
not a dictionary.")
 
         # Helper functions
-        def all_of(types):
+        def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
             """Given a type or tuple of types and a sequence of xs
             check if each x is instance of type(s)
 
@@ -2357,7 +2578,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             >>> all_of(str)(["a", 1])
             False
             """
-            def all_of_(xs):
+            def all_of_(xs: Iterable) -> bool:
                 return all(isinstance(x, types) for x in xs)
             return all_of_
 
@@ -2415,7 +2636,30 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
             return DataFrame(
                 self._jdf.na().replace(self._jseq(subset), 
self._jmap(rep_dict)), self.sql_ctx)
 
-    def approxQuantile(self, col, probabilities, relativeError):
+    @overload
+    def approxQuantile(
+        self,
+        col: str,
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float,
+    ) -> List[float]:
+        ...
+
+    @overload
+    def approxQuantile(
+        self,
+        col: Union[List[str], Tuple[str]],
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float,
+    ) -> List[List[float]]:
+        ...
+
+    def approxQuantile(
+        self,
+        col: Union[str, List[str], Tuple[str]],
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float
+    ) -> Union[List[float], List[List[float]]]:
         """
         Calculates the approximate quantiles of numerical columns of a
         :class:`DataFrame`.
@@ -2474,7 +2718,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         if isinstance(col, tuple):
             col = list(col)
         elif isStr:
-            col = [col]
+            col = [cast(str, col)]
 
         for c in col:
             if not isinstance(c, str):
@@ -2500,7 +2744,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jaq_list = [list(j) for j in jaq]
         return jaq_list[0] if isStr else jaq_list
 
-    def corr(self, col1, col2, method=None):
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
         """
         Calculates the correlation of two columns of a :class:`DataFrame` as a 
double value.
         Currently only supports the Pearson Correlation Coefficient.
@@ -2528,7 +2772,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                              "coefficient is supported.")
         return self._jdf.stat().corr(col1, col2, method)
 
-    def cov(self, col1, col2):
+    def cov(self, col1: str, col2: str) -> float:
         """
         Calculate the sample covariance for the given columns, specified by 
their names, as a
         double value. :func:`DataFrame.cov` and 
:func:`DataFrameStatFunctions.cov` are aliases.
@@ -2548,7 +2792,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
-    def crosstab(self, col1, col2):
+    def crosstab(self, col1: str, col2: str) -> "DataFrame":
         """
         Computes a pair-wise frequency table of the given columns. Also known 
as a contingency
         table. The number of distinct values for each column should be less 
than 1e4. At most 1e6
@@ -2575,7 +2819,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("col2 should be a string.")
         return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
 
-    def freqItems(self, cols, support=None):
+    def freqItems(
+        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
+    ) -> "DataFrame":
         """
         Finding frequent items for columns, possibly with false positives. 
Using the
         frequent element count algorithm described in
@@ -2607,7 +2853,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             support = 0.01
         return DataFrame(self._jdf.stat().freqItems(_to_seq(self._sc, cols), 
support), self.sql_ctx)
 
-    def withColumn(self, colName, col):
+    def withColumn(self, colName: str, col: Column) -> "DataFrame":
         """
         Returns a new :class:`DataFrame` by adding a column or replacing the
         existing column that has the same name.
@@ -2641,7 +2887,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             raise TypeError("col should be Column")
         return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
 
-    def withColumnRenamed(self, existing, new):
+    def withColumnRenamed(self, existing: str, new: str) -> "DataFrame":
         """Returns a new :class:`DataFrame` by renaming an existing column.
         This is a no-op if schema doesn't contain the given column name.
 
@@ -2661,7 +2907,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrame(self._jdf.withColumnRenamed(existing, new), 
self.sql_ctx)
 
-    def withMetadata(self, columnName, metadata):
+    def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame":
         """Returns a new :class:`DataFrame` by updating an existing column 
with metadata.
 
         .. versionadded:: 3.3.0
@@ -2681,12 +2927,20 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         """
         if not isinstance(metadata, dict):
             raise TypeError("metadata should be a dict")
-        sc = SparkContext._active_spark_context
+        sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
         jmeta = sc._jvm.org.apache.spark.sql.types.Metadata.fromJson(
             json.dumps(metadata))
         return DataFrame(self._jdf.withMetadata(columnName, jmeta), 
self.sql_ctx)
 
-    def drop(self, *cols):
+    @overload
+    def drop(self, cols: "ColumnOrName") -> "DataFrame":
+        ...
+
+    @overload
+    def drop(self, *cols: str) -> "DataFrame":
+        ...
+
+    def drop(self, *cols: "ColumnOrName") -> "DataFrame":  # type: ignore[misc]
         """Returns a new :class:`DataFrame` that drops the specified column.
         This is a no-op if schema doesn't contain the given column name(s).
 
@@ -2730,7 +2984,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         return DataFrame(jdf, self.sql_ctx)
 
-    def toDF(self, *cols):
+    def toDF(self, *cols: "ColumnOrName") -> "DataFrame":
         """Returns a new :class:`DataFrame` that with new specified column 
names
 
         Parameters
@@ -2746,7 +3000,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jdf = self._jdf.toDF(self._jseq(cols))
         return DataFrame(jdf, self.sql_ctx)
 
-    def transform(self, func):
+    def transform(self, func: Callable[["DataFrame"], "DataFrame"]) -> "DataFrame":
         """Returns a new :class:`DataFrame`. Concise syntax for chaining 
custom transformations.
 
         .. versionadded:: 3.0.0
@@ -2777,7 +3031,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                                               "should have been DataFrame." % 
type(result)
         return result
 
-    def sameSemantics(self, other):
+    def sameSemantics(self, other: "DataFrame") -> bool:
         """
         Returns `True` when the logical query plans inside both 
:class:`DataFrame`\\s are equal and
         therefore return same results.
@@ -2811,7 +3065,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
                             % type(other))
         return self._jdf.sameSemantics(other._jdf)
 
-    def semanticHash(self):
+    def semanticHash(self) -> int:
         """
         Returns a hash code of the logical query plan against this 
:class:`DataFrame`.
 
@@ -2833,7 +3087,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return self._jdf.semanticHash()
 
-    def inputFiles(self):
+    def inputFiles(self) -> List[str]:
         """
         Returns a best-effort snapshot of the files that compose this 
:class:`DataFrame`.
         This method simply asks each constituent BaseRelation for its 
respective files and
@@ -2870,7 +3124,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         sinceversion=1.4,
         doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.")
 
-    def writeTo(self, table):
+    def writeTo(self, table: str) -> DataFrameWriterV2:
         """
         Create a write configuration builder for v2 sources.
 
@@ -2889,7 +3143,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         """
         return DataFrameWriterV2(self, table)
 
-    def to_pandas_on_spark(self, index_col=None):
+    def to_pandas_on_spark(
+        self, index_col: Optional[Union[str, List[str]]] = None
+    ) -> "PandasOnSparkDataFrame":
         """
         Converts the existing DataFrame into a pandas-on-Spark DataFrame.
 
@@ -2935,17 +3191,20 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         c        3
         """
         from pyspark.pandas.namespace import _get_index_map
-        from pyspark.pandas.frame import DataFrame
+        from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
         from pyspark.pandas.internal import InternalFrame
 
         index_spark_columns, index_names = _get_index_map(self, index_col)
         internal = InternalFrame(
-            spark_frame=self, index_spark_columns=index_spark_columns, 
index_names=index_names
+            spark_frame=self, index_spark_columns=index_spark_columns,
+            index_names=index_names  # type: ignore[arg-type]
         )
-        return DataFrame(internal)
+        return PandasOnSparkDataFrame(internal)
 
     # Keep to_koalas for backward compatibility for now.
-    def to_koalas(self, index_col=None):
+    def to_koalas(
+        self, index_col: Optional[Union[str, List[str]]] = None
+    ) -> "PandasOnSparkDataFrame":
         warnings.warn(
             "DataFrame.to_koalas is deprecated. Use 
DataFrame.to_pandas_on_spark instead.",
             FutureWarning,
@@ -2953,11 +3212,11 @@ class DataFrame(PandasMapOpsMixin, 
PandasConversionMixin):
         return self.to_pandas_on_spark(index_col)
 
 
-def _to_scala_map(sc, jm):
+def _to_scala_map(sc: SparkContext, jm: Dict) -> JavaObject:
     """
     Convert a dict into a JVM Map.
     """
-    return sc._jvm.PythonUtils.toScalaMap(jm)
+    return sc._jvm.PythonUtils.toScalaMap(jm)  # type: ignore[attr-defined]
 
 
 class DataFrameNaFunctions(object):
@@ -2966,21 +3225,70 @@ class DataFrameNaFunctions(object):
     .. versionadded:: 1.4
     """
 
-    def __init__(self, df):
+    def __init__(self, df: DataFrame):
         self.df = df
 
-    def drop(self, how='any', thresh=None, subset=None):
+    def drop(
+        self, how: str = 'any', thresh: Optional[int] = None, subset: Optional[List[str]] = None
+    ) -> DataFrame:
         return self.df.dropna(how=how, thresh=thresh, subset=subset)
 
     drop.__doc__ = DataFrame.dropna.__doc__
 
-    def fill(self, value, subset=None):
-        return self.df.fillna(value=value, subset=subset)
+    @overload
+    def fill(
+        self, value: "LiteralType", subset: Optional[List[str]] = ...
+    ) -> DataFrame:
+        ...
+
+    @overload
+    def fill(self, value: Dict[str, "LiteralType"]) -> DataFrame:
+        ...
+
+    def fill(
+        self,
+        value: Union["LiteralType", Dict[str, "LiteralType"]],
+        subset: Optional[List[str]] = None
+    ) -> DataFrame:
+        return self.df.fillna(value=value, subset=subset)  # type: ignore[arg-type]
 
     fill.__doc__ = DataFrame.fillna.__doc__
 
-    def replace(self, to_replace, value=_NoValue, subset=None):
-        return self.df.replace(to_replace, value, subset)
+    @overload
+    def replace(
+        self,
+        to_replace: List["LiteralType"],
+        value: List["OptionalPrimitiveType"],
+        subset: Optional[List[str]] = ...,
+    ) -> DataFrame:
+        ...
+
+    @overload
+    def replace(
+        self,
+        to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
+        subset: Optional[List[str]] = ...,
+    ) -> DataFrame:
+        ...
+
+    @overload
+    def replace(
+        self,
+        to_replace: List["LiteralType"],
+        value: "OptionalPrimitiveType",
+        subset: Optional[List[str]] = ...,
+    ) -> DataFrame:
+        ...
+
+    def replace(  # type: ignore[misc]
+        self,
+        to_replace: Union[
+            List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
+        ],
+        value: Optional[Union["OptionalPrimitiveType", List["OptionalPrimitiveType"]]] = _NoValue,
+        subset: Optional[List[str]] = None
+    ) -> DataFrame:
+        return self.df.replace(to_replace, value, subset)  # type: ignore[arg-type]
 
     replace.__doc__ = DataFrame.replace.__doc__
 
@@ -2991,41 +3299,66 @@ class DataFrameStatFunctions(object):
     .. versionadded:: 1.4
     """
 
-    def __init__(self, df):
+    def __init__(self, df: DataFrame):
         self.df = df
 
-    def approxQuantile(self, col, probabilities, relativeError):
+    @overload
+    def approxQuantile(
+        self,
+        col: str,
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float,
+    ) -> List[float]:
+        ...
+
+    @overload
+    def approxQuantile(
+        self,
+        col: Union[List[str], Tuple[str]],
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float,
+    ) -> List[List[float]]:
+        ...
+
+    def approxQuantile(  # type: ignore[misc]
+        self,
+        col: Union[str, List[str], Tuple[str]],
+        probabilities: Union[List[float], Tuple[float]],
+        relativeError: float
+    ) -> Union[List[float], List[List[float]]]:
         return self.df.approxQuantile(col, probabilities, relativeError)
 
     approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__
 
-    def corr(self, col1, col2, method=None):
+    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
         return self.df.corr(col1, col2, method)
 
     corr.__doc__ = DataFrame.corr.__doc__
 
-    def cov(self, col1, col2):
+    def cov(self, col1: str, col2: str) -> float:
         return self.df.cov(col1, col2)
 
     cov.__doc__ = DataFrame.cov.__doc__
 
-    def crosstab(self, col1, col2):
+    def crosstab(self, col1: str, col2: str) -> DataFrame:
         return self.df.crosstab(col1, col2)
 
     crosstab.__doc__ = DataFrame.crosstab.__doc__
 
-    def freqItems(self, cols, support=None):
+    def freqItems(self, cols: List[str], support: Optional[float] = None) -> DataFrame:
         return self.df.freqItems(cols, support)
 
     freqItems.__doc__ = DataFrame.freqItems.__doc__
 
-    def sampleBy(self, col, fractions, seed=None):
+    def sampleBy(
+        self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
+    ) -> DataFrame:
         return self.df.sampleBy(col, fractions, seed)
 
     sampleBy.__doc__ = DataFrame.sampleBy.__doc__
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.context import SparkContext
     from pyspark.sql import Row, SQLContext, SparkSession
diff --git a/python/pyspark/sql/dataframe.pyi b/python/pyspark/sql/dataframe.pyi
deleted file mode 100644
index d903a79..0000000
--- a/python/pyspark/sql/dataframe.pyi
+++ /dev/null
@@ -1,351 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
-
-from py4j.java_gateway import JavaObject  # type: ignore[import]
-
-from pyspark.sql._typing import ColumnOrName, LiteralType, OptionalPrimitiveType
-from pyspark._typing import PrimitiveType
-from pyspark.sql.types import (  # noqa: F401
-    StructType,
-    StructField,
-    StringType,
-    IntegerType,
-    Row,
-)  # noqa: F401
-from pyspark.sql.context import SQLContext
-from pyspark.sql.group import GroupedData
-from pyspark.sql.observation import Observation
-from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
-from pyspark.sql.streaming import DataStreamWriter
-from pyspark.sql.column import Column
-from pyspark.rdd import RDD
-from pyspark.storagelevel import StorageLevel
-
-from pyspark.sql.pandas.conversion import PandasConversionMixin
-from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
-from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
-
-class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
-    sql_ctx: SQLContext
-    is_cached: bool
-    def __init__(self, jdf: JavaObject, sql_ctx: SQLContext) -> None: ...
-    @property
-    def rdd(self) -> RDD[Row]: ...
-    @property
-    def na(self) -> DataFrameNaFunctions: ...
-    @property
-    def stat(self) -> DataFrameStatFunctions: ...
-    def toJSON(self, use_unicode: bool = ...) -> RDD[str]: ...
-    def registerTempTable(self, name: str) -> None: ...
-    def createTempView(self, name: str) -> None: ...
-    def createOrReplaceTempView(self, name: str) -> None: ...
-    def createGlobalTempView(self, name: str) -> None: ...
-    @property
-    def write(self) -> DataFrameWriter: ...
-    @property
-    def writeStream(self) -> DataStreamWriter: ...
-    @property
-    def schema(self) -> StructType: ...
-    def printSchema(self) -> None: ...
-    def explain(
-        self, extended: Optional[Union[bool, str]] = ..., mode: Optional[str] = ...
-    ) -> None: ...
-    def exceptAll(self, other: DataFrame) -> DataFrame: ...
-    def isLocal(self) -> bool: ...
-    @property
-    def isStreaming(self) -> bool: ...
-    def show(
-        self, n: int = ..., truncate: Union[bool, int] = ..., vertical: bool = ...
-    ) -> None: ...
-    def checkpoint(self, eager: bool = ...) -> DataFrame: ...
-    def localCheckpoint(self, eager: bool = ...) -> DataFrame: ...
-    def withWatermark(
-        self, eventTime: str, delayThreshold: str
-    ) -> DataFrame: ...
-    def hint(self, name: str, *parameters: Union[PrimitiveType, List[PrimitiveType]]) -> DataFrame: ...
-    def count(self) -> int: ...
-    def collect(self) -> List[Row]: ...
-    def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[Row]: ...
-    def limit(self, num: int) -> DataFrame: ...
-    def take(self, num: int) -> List[Row]: ...
-    def tail(self, num: int) -> List[Row]: ...
-    def foreach(self, f: Callable[[Row], None]) -> None: ...
-    def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None: ...
-    def cache(self) -> DataFrame: ...
-    def persist(self, storageLevel: StorageLevel = ...) -> DataFrame: ...
-    @property
-    def storageLevel(self) -> StorageLevel: ...
-    def unpersist(self, blocking: bool = ...) -> DataFrame: ...
-    def coalesce(self, numPartitions: int) -> DataFrame: ...
-    @overload
-    def repartition(self, numPartitions: int, *cols: ColumnOrName) -> DataFrame: ...
-    @overload
-    def repartition(self, *cols: ColumnOrName) -> DataFrame: ...
-    @overload
-    def repartitionByRange(
-        self, numPartitions: int, *cols: ColumnOrName
-    ) -> DataFrame: ...
-    @overload
-    def repartitionByRange(self, *cols: ColumnOrName) -> DataFrame: ...
-    def distinct(self) -> DataFrame: ...
-    @overload
-    def sample(self, fraction: float, seed: Optional[int] = ...) -> DataFrame: ...
-    @overload
-    def sample(
-        self,
-        withReplacement: Optional[bool],
-        fraction: float,
-        seed: Optional[int] = ...,
-    ) -> DataFrame: ...
-    def sampleBy(
-        self, col: ColumnOrName, fractions: Dict[Any, float], seed: Optional[int] = ...
-    ) -> DataFrame: ...
-    def randomSplit(
-        self, weights: List[float], seed: Optional[int] = ...
-    ) -> List[DataFrame]: ...
-    @property
-    def dtypes(self) -> List[Tuple[str, str]]: ...
-    @property
-    def columns(self) -> List[str]: ...
-    def colRegex(self, colName: str) -> Column: ...
-    def alias(self, alias: str) -> DataFrame: ...
-    def crossJoin(self, other: DataFrame) -> DataFrame: ...
-    def join(
-        self,
-        other: DataFrame,
-        on: Optional[Union[str, List[str], Column, List[Column]]] = ...,
-        how: Optional[str] = ...,
-    ) -> DataFrame: ...
-    def sortWithinPartitions(
-        self,
-        *cols: Union[str, Column, List[Union[str, Column]]],
-        ascending: Union[bool, List[bool]] = ...
-    ) -> DataFrame: ...
-    def sort(
-        self,
-        *cols: Union[str, Column, List[Union[str, Column]]],
-        ascending: Union[bool, List[bool]] = ...
-    ) -> DataFrame: ...
-    def orderBy(
-        self,
-        *cols: Union[str, Column, List[Union[str, Column]]],
-        ascending: Union[bool, List[bool]] = ...
-    ) -> DataFrame: ...
-    def describe(self, *cols: Union[str, List[str]]) -> DataFrame: ...
-    def summary(self, *statistics: str) -> DataFrame: ...
-    @overload
-    def head(self) -> Row: ...
-    @overload
-    def head(self, n: int) -> List[Row]: ...
-    def first(self) -> Row: ...
-    def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Column: ...
-    def __getattr__(self, name: str) -> Column: ...
-    @overload
-    def select(self, *cols: ColumnOrName) -> DataFrame: ...
-    @overload
-    def select(self, __cols: Union[List[Column], List[str]]) -> DataFrame: ...
-    @overload
-    def selectExpr(self, *expr: str) -> DataFrame: ...
-    @overload
-    def selectExpr(self, *expr: List[str]) -> DataFrame: ...
-    def filter(self, condition: ColumnOrName) -> DataFrame: ...
-    @overload
-    def groupBy(self, *cols: ColumnOrName) -> GroupedData: ...
-    @overload
-    def groupBy(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
-    @overload
-    def rollup(self, *cols: ColumnOrName) -> GroupedData: ...
-    @overload
-    def rollup(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
-    @overload
-    def cube(self, *cols: ColumnOrName) -> GroupedData: ...
-    @overload
-    def cube(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
-    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> DataFrame: ...
-    def observe(self, observation: Observation, *exprs: Column) -> DataFrame: ...
-    def union(self, other: DataFrame) -> DataFrame: ...
-    def unionAll(self, other: DataFrame) -> DataFrame: ...
-    def unionByName(
-        self, other: DataFrame, allowMissingColumns: bool = ...
-    ) -> DataFrame: ...
-    def intersect(self, other: DataFrame) -> DataFrame: ...
-    def intersectAll(self, other: DataFrame) -> DataFrame: ...
-    def subtract(self, other: DataFrame) -> DataFrame: ...
-    def dropDuplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ...
-    def dropna(
-        self,
-        how: str = ...,
-        thresh: Optional[int] = ...,
-        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def fillna(
-        self,
-        value: LiteralType,
-        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def fillna(self, value: Dict[str, LiteralType]) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: LiteralType,
-        value: OptionalPrimitiveType,
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: List[LiteralType],
-        value: List[OptionalPrimitiveType],
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: Dict[LiteralType, OptionalPrimitiveType],
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: List[LiteralType],
-        value: OptionalPrimitiveType,
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def approxQuantile(
-        self,
-        col: str,
-        probabilities: Union[List[float], Tuple[float]],
-        relativeError: float,
-    ) -> List[float]: ...
-    @overload
-    def approxQuantile(
-        self,
-        col: Union[List[str], Tuple[str]],
-        probabilities: Union[List[float], Tuple[float]],
-        relativeError: float,
-    ) -> List[List[float]]: ...
-    def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ...
-    def cov(self, col1: str, col2: str) -> float: ...
-    def crosstab(self, col1: str, col2: str) -> DataFrame: ...
-    def freqItems(
-        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = ...
-    ) -> DataFrame: ...
-    def withColumn(self, colName: str, col: Column) -> DataFrame: ...
-    def withColumnRenamed(self, existing: str, new: str) -> DataFrame: ...
-    @overload
-    def drop(self, cols: ColumnOrName) -> DataFrame: ...
-    @overload
-    def drop(self, *cols: str) -> DataFrame: ...
-    def toDF(self, *cols: ColumnOrName) -> DataFrame: ...
-    def transform(self, func: Callable[[DataFrame], DataFrame]) -> DataFrame: ...
-    @overload
-    def groupby(self, *cols: ColumnOrName) -> GroupedData: ...
-    @overload
-    def groupby(self, __cols: Union[List[Column], List[str]]) -> GroupedData: ...
-    def drop_duplicates(self, subset: Optional[List[str]] = ...) -> DataFrame: ...
-    def where(self, condition: ColumnOrName) -> DataFrame: ...
-    def sameSemantics(self, other: DataFrame) -> bool: ...
-    def semanticHash(self) -> int: ...
-    def inputFiles(self) -> List[str]: ...
-    def writeTo(self, table: str) -> DataFrameWriterV2: ...
-    def to_pandas_on_spark(self, index_col: Optional[Union[str, List[str]]] = None) -> PandasOnSparkDataFrame: ...
-
-class DataFrameNaFunctions:
-    df: DataFrame
-    def __init__(self, df: DataFrame) -> None: ...
-    def drop(
-        self,
-        how: str = ...,
-        thresh: Optional[int] = ...,
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def fill(
-        self, value: LiteralType, subset: Optional[List[str]] = ...
-    ) -> DataFrame: ...
-    @overload
-    def fill(self, value: Dict[str, LiteralType]) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: LiteralType,
-        value: OptionalPrimitiveType,
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: List[LiteralType],
-        value: List[OptionalPrimitiveType],
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: Dict[LiteralType, OptionalPrimitiveType],
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-    @overload
-    def replace(
-        self,
-        to_replace: List[LiteralType],
-        value: OptionalPrimitiveType,
-        subset: Optional[List[str]] = ...,
-    ) -> DataFrame: ...
-
-class DataFrameStatFunctions:
-    df: DataFrame
-    def __init__(self, df: DataFrame) -> None: ...
-    @overload
-    def approxQuantile(
-        self,
-        col: str,
-        probabilities: Union[List[float], Tuple[float]],
-        relativeError: float,
-    ) -> List[float]: ...
-    @overload
-    def approxQuantile(
-        self,
-        col: Union[List[str], Tuple[str]],
-        probabilities: Union[List[float], Tuple[float]],
-        relativeError: float,
-    ) -> List[List[float]]: ...
-    def corr(self, col1: str, col2: str, method: Optional[str] = ...) -> float: ...
-    def cov(self, col1: str, col2: str) -> float: ...
-    def crosstab(self, col1: str, col2: str) -> DataFrame: ...
-    def freqItems(
-        self, cols: List[str], support: Optional[float] = ...
-    ) -> DataFrame: ...
-    def sampleBy(
-        self, col: str, fractions: Dict[Any, float], seed: Optional[int] = ...
-    ) -> DataFrame: ...
diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py
index 3e8a0d1..48d8176 100644
--- a/python/pyspark/sql/observation.py
+++ b/python/pyspark/sql/observation.py
@@ -99,7 +99,7 @@ class Observation:
         assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
         assert self._jo is None, "an Observation can be used with a DataFrame only once"
 
-        self._jvm = df._sc._jvm  # type: ignore[assignment]
+        self._jvm = df._sc._jvm  # type: ignore[assignment, attr-defined]
         cls = self._jvm.org.apache.spark.sql.Observation  # type: ignore[attr-defined]
         self._jo = cls(self._name) if self._name is not None else cls()
         observed_df = self._jo.on(  # type: ignore[attr-defined]

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to