This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a9b4c27  [SPARK-36846][PYTHON] Inline most of type hint files under pyspark/sql/pandas folder
a9b4c27 is described below

commit a9b4c27f12e308b3674e1d6faf6ff998c1aed146
Author: Takuya UESHIN <[email protected]>
AuthorDate: Wed Sep 29 09:25:18 2021 +0900

    [SPARK-36846][PYTHON] Inline most of type hint files under pyspark/sql/pandas folder
    
    ### What changes were proposed in this pull request?
    
    Inlines type hint files under the `pyspark/sql/pandas` folder, except for `pyspark/sql/pandas/functions.pyi` and files under `pyspark/sql/pandas/_typing`.
    
    - Since `pyspark/sql/pandas/functions.pyi` contains a lot of overloads, we should revisit and manage it separately.
    - We can't inline files under `pyspark/sql/pandas/_typing` because they include new syntax for type hints.
    
    ### Why are the changes needed?
    
    Currently there are type hint stub files (`*.pyi`) to show the expected types for functions, but we can also take advantage of static type checking within the functions by inlining the type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #34101 from ueshin/issues/SPARK-36846/inline_typehints.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../pyspark/sql/pandas/_typing/protocols/frame.pyi |   5 +
 python/pyspark/sql/pandas/conversion.py            | 120 +++++++++++++++------
 python/pyspark/sql/pandas/conversion.pyi           |  59 ----------
 python/pyspark/sql/pandas/group_ops.py             |  46 +++++---
 python/pyspark/sql/pandas/group_ops.pyi            |  49 ---------
 python/pyspark/sql/pandas/map_ops.py               |  14 ++-
 python/pyspark/sql/pandas/map_ops.pyi              |  30 ------
 python/pyspark/sql/pandas/typehints.py             |  24 ++++-
 python/pyspark/sql/pandas/types.py                 |  42 +++++---
 python/pyspark/sql/pandas/utils.py                 |   4 +-
 10 files changed, 186 insertions(+), 207 deletions(-)

diff --git a/python/pyspark/sql/pandas/_typing/protocols/frame.pyi 
b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
index e219c1c..6f450df 100644
--- a/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
+++ b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
@@ -35,6 +35,8 @@ Axis = Any
 Level = Any
 
 class DataFrameLike(Protocol):
+    columns: Axes
+    dtypes: List[Any]
     def __init__(
         self,
         data: Any = ...,
@@ -422,7 +424,10 @@ class DataFrameLike(Protocol):
         self, freq: Any = ..., axis: Any = ..., copy: Any = ...
     ) -> DataFrameLike: ...
     def isin(self, values: Any) -> DataFrameLike: ...
+    def copy(self) -> DataFrameLike: ...
     plot: Any = ...
     hist: Any = ...
     boxplot: Any = ...
     sparse: Any = ...
+    loc: Any = ...
+    iloc: Any = ...
diff --git a/python/pyspark/sql/pandas/conversion.py 
b/python/pyspark/sql/pandas/conversion.py
index 6761217..354d3a9 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -17,8 +17,9 @@
 import sys
 import warnings
 from collections import Counter
+from typing import List, Optional, Type, Union, no_type_check, overload, 
TYPE_CHECKING
 
-from pyspark.rdd import _load_from_socket
+from pyspark.rdd import _load_from_socket  # type: ignore[attr-defined]
 from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import IntegralType
 from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, 
FloatType, \
@@ -26,6 +27,13 @@ from pyspark.sql.types import ByteType, ShortType, 
IntegerType, LongType, FloatT
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
 
+if TYPE_CHECKING:
+    import numpy as np
+    import pyarrow as pa
+
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+    from pyspark.sql import DataFrame
+
 
 class PandasConversionMixin(object):
     """
@@ -33,7 +41,7 @@ class PandasConversionMixin(object):
     can use this class.
     """
 
-    def toPandas(self):
+    def toPandas(self) -> "PandasDataFrameLike":
         """
         Returns the contents of this :class:`DataFrame` as Pandas 
``pandas.DataFrame``.
 
@@ -65,9 +73,9 @@ class PandasConversionMixin(object):
         import numpy as np
         import pandas as pd
 
-        timezone = self.sql_ctx._conf.sessionLocalTimeZone()
+        timezone = self.sql_ctx._conf.sessionLocalTimeZone()  # type: 
ignore[attr-defined]
 
-        if self.sql_ctx._conf.arrowPySparkEnabled():
+        if self.sql_ctx._conf.arrowPySparkEnabled():  # type: 
ignore[attr-defined]
             use_arrow = True
             try:
                 from pyspark.sql.pandas.types import to_arrow_schema
@@ -77,7 +85,7 @@ class PandasConversionMixin(object):
                 to_arrow_schema(self.schema)
             except Exception as e:
 
-                if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
+                if self.sql_ctx._conf.arrowPySparkFallbackEnabled():  # type: 
ignore[attr-defined]
                     msg = (
                         "toPandas attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to 
true; however, "
@@ -106,7 +114,10 @@ class PandasConversionMixin(object):
                     import pyarrow
                     # Rename columns to avoid duplicated column names.
                     tmp_column_names = ['col_{}'.format(i) for i in 
range(len(self.columns))]
-                    self_destruct = 
self.sql_ctx._conf.arrowPySparkSelfDestructEnabled()
+                    self_destruct = (
+                        self.sql_ctx._conf  # type: ignore[attr-defined]
+                            .arrowPySparkSelfDestructEnabled()
+                    )
                     batches = self.toDF(*tmp_column_names)._collect_as_arrow(
                         split_batches=self_destruct)
                     if len(batches) > 0:
@@ -158,7 +169,7 @@ class PandasConversionMixin(object):
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
         column_counter = Counter(self.columns)
 
-        dtype = [None] * len(self.schema)
+        dtype = [None] * len(self.schema)  # type: List[Optional[Type]]
         for fieldIdx, field in enumerate(self.schema):
             # For duplicate column name, we use `iloc` to access it.
             if column_counter[field.name] > 1:
@@ -179,7 +190,7 @@ class PandasConversionMixin(object):
             if isinstance(field.dataType, IntegralType) and 
pandas_col.isnull().any():
                 dtype[fieldIdx] = np.float64
             if isinstance(field.dataType, BooleanType) and 
pandas_col.isnull().any():
-                dtype[fieldIdx] = np.object
+                dtype[fieldIdx] = np.object  # type: ignore[attr-defined]
 
         df = pd.DataFrame()
         for index, t in enumerate(dtype):
@@ -216,7 +227,7 @@ class PandasConversionMixin(object):
             return pdf
 
     @staticmethod
-    def _to_corrected_pandas_type(dt):
+    def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]:
         """
         When converting Spark SQL records to Pandas :class:`DataFrame`, the 
inferred data type
         may be wrong. This method gets the corrected data type for Pandas if 
that type may be
@@ -236,7 +247,7 @@ class PandasConversionMixin(object):
         elif type(dt) == DoubleType:
             return np.float64
         elif type(dt) == BooleanType:
-            return np.bool
+            return np.bool  # type: ignore[attr-defined]
         elif type(dt) == TimestampType:
             return np.datetime64
         elif type(dt) == TimestampNTZType:
@@ -244,7 +255,7 @@ class PandasConversionMixin(object):
         else:
             return None
 
-    def _collect_as_arrow(self, split_batches=False):
+    def _collect_as_arrow(self, split_batches: bool = False) -> 
List["pa.RecordBatch"]:
         """
         Returns all records as a list of ArrowRecordBatches, pyarrow must be 
installed
         and available on driver and worker Python environments.
@@ -260,7 +271,9 @@ class PandasConversionMixin(object):
         assert isinstance(self, DataFrame)
 
         with SCCallSiteSync(self._sc):
-            port, auth_secret, jsocket_auth_server = 
self._jdf.collectAsArrowToPython()
+            port, auth_secret, jsocket_auth_server = (
+                self._jdf.collectAsArrowToPython()  # type: ignore[operator]
+            )
 
         # Collect list of un-ordered batches where last element is a list of 
correct order indices
         try:
@@ -301,7 +314,29 @@ class SparkConversionMixin(object):
     Min-in for the conversion from pandas to Spark. Currently, only 
:class:`SparkSession`
     can use this class.
     """
-    def createDataFrame(self, data, schema=None, samplingRatio=None, 
verifySchema=True):
+
+    @overload
+    def createDataFrame(
+        self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+    ) -> "DataFrame":
+        ...
+
+    @overload
+    def createDataFrame(
+        self,
+        data: "PandasDataFrameLike",
+        schema: Union[StructType, str],
+        verifySchema: bool = ...,
+    ) -> "DataFrame":
+        ...
+
+    def createDataFrame(  # type: ignore[misc]
+        self,
+        data: "PandasDataFrameLike",
+        schema: Optional[Union[StructType, List[str]]] = None,
+        samplingRatio: Optional[float] = None,
+        verifySchema: bool = True
+    ) -> "DataFrame":
         from pyspark.sql import SparkSession
 
         assert isinstance(self, SparkSession)
@@ -309,19 +344,20 @@ class SparkConversionMixin(object):
         from pyspark.sql.pandas.utils import require_minimum_pandas_version
         require_minimum_pandas_version()
 
-        timezone = self._wrapped._conf.sessionLocalTimeZone()
+        timezone = self._wrapped._conf.sessionLocalTimeZone()  # type: 
ignore[attr-defined]
 
         # If no schema supplied by user then get the names of columns only
         if schema is None:
-            schema = [str(x) if not isinstance(x, str) else
-                      (x.encode('utf-8') if not isinstance(x, str) else x)
-                      for x in data.columns]
+            schema = [str(x) if not isinstance(x, str) else x for x in 
data.columns]
 
-        if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
+        if (
+            self._wrapped._conf.arrowPySparkEnabled()  # type: 
ignore[attr-defined]
+            and len(data) > 0
+        ):
             try:
                 return self._create_from_pandas_with_arrow(data, schema, 
timezone)
             except Exception as e:
-                if self._wrapped._conf.arrowPySparkFallbackEnabled():
+                if self._wrapped._conf.arrowPySparkFallbackEnabled():  # type: 
ignore[attr-defined]
                     msg = (
                         "createDataFrame attempted Arrow optimization because "
                         "'spark.sql.execution.arrow.pyspark.enabled' is set to 
true; however, "
@@ -339,10 +375,17 @@ class SparkConversionMixin(object):
                         "has been set to false.\n  %s" % str(e))
                     warnings.warn(msg)
                     raise
-        data = self._convert_from_pandas(data, schema, timezone)
-        return self._create_dataframe(data, schema, samplingRatio, 
verifySchema)
-
-    def _convert_from_pandas(self, pdf, schema, timezone):
+        converted_data = self._convert_from_pandas(data, schema, timezone)
+        return self._create_dataframe(  # type: ignore[attr-defined]
+            converted_data, schema, samplingRatio, verifySchema
+        )
+
+    def _convert_from_pandas(
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, str, List[str]],
+        timezone: str
+    ) -> List:
         """
          Convert a pandas.DataFrame to list of records that can be used to 
make a DataFrame
 
@@ -398,7 +441,7 @@ class SparkConversionMixin(object):
         # Convert list of numpy records to python lists
         return [r.tolist() for r in np_records]
 
-    def _get_numpy_record_dtype(self, rec):
+    def _get_numpy_record_dtype(self, rec: "np.recarray") -> 
Optional["np.dtype"]:
         """
         Used when converting a pandas.DataFrame to Spark using to_records(), 
this will correct
         the dtypes of fields in a record so they can be properly loaded into 
Spark.
@@ -429,7 +472,12 @@ class SparkConversionMixin(object):
             record_type_list.append((str(col_names[i]), curr_type))
         return np.dtype(record_type_list) if has_rec_fix else None
 
-    def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
+    def _create_from_pandas_with_arrow(
+        self,
+        pdf: "PandasDataFrameLike",
+        schema: Union[StructType, List[str]],
+        timezone: str
+    ) -> "DataFrame":
         """
         Create a DataFrame from a given pandas.DataFrame by slicing it into 
partitions, converting
         to Arrow data, then sending to the JVM to parallelize. If a schema is 
passed in, the
@@ -483,27 +531,35 @@ class SparkConversionMixin(object):
         arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), 
arrow_types)]
                       for pdf_slice in pdf_slices]
 
-        jsqlContext = self._wrapped._jsqlContext
+        jsqlContext = self._wrapped._jsqlContext  # type: ignore[attr-defined]
 
-        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        safecheck = self._wrapped._conf.arrowSafeTypeConversion()  # type: 
ignore[attr-defined]
         col_by_name = True  # col by name only applies to StructType columns, 
can't happen here
         ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
 
+        @no_type_check
         def reader_func(temp_filename):
             return 
self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
 
+        @no_type_check
         def create_RDD_server():
             return self._jvm.ArrowRDDServer(jsqlContext)
 
         # Create Spark DataFrame from Arrow stream file, using one batch per 
partition
-        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, 
create_RDD_server)
-        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), 
jsqlContext)
-        df = DataFrame(jdf, self._wrapped)
-        df._schema = schema
+        jrdd = (
+            self._sc  # type: ignore[attr-defined]
+                ._serialize_to_jvm(arrow_data, ser, reader_func, 
create_RDD_server)
+        )
+        jdf = (
+            self._jvm  # type: ignore[attr-defined]
+                .PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
+        )
+        df = DataFrame(jdf, self._wrapped)  # type: ignore[attr-defined]
+        df._schema = schema  # type: ignore[attr-defined]
         return df
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.pandas.conversion
diff --git a/python/pyspark/sql/pandas/conversion.pyi 
b/python/pyspark/sql/pandas/conversion.pyi
deleted file mode 100644
index 87637722..0000000
--- a/python/pyspark/sql/pandas/conversion.pyi
+++ /dev/null
@@ -1,59 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Optional, Union
-
-from pyspark.sql.pandas._typing import DataFrameLike
-from pyspark import since as since  # noqa: F401
-from pyspark.rdd import RDD  # noqa: F401
-import pyspark.sql.dataframe
-from pyspark.sql.pandas.serializers import (  # noqa: F401
-    ArrowCollectSerializer as ArrowCollectSerializer,
-)
-from pyspark.sql.types import (  # noqa: F401
-    BooleanType as BooleanType,
-    ByteType as ByteType,
-    DataType as DataType,
-    DoubleType as DoubleType,
-    FloatType as FloatType,
-    IntegerType as IntegerType,
-    IntegralType as IntegralType,
-    LongType as LongType,
-    ShortType as ShortType,
-    StructType as StructType,
-    TimestampType as TimestampType,
-    TimestampNTZType as TimestampNTZType,
-)
-from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync  # noqa: 
F401
-
-class PandasConversionMixin:
-    def toPandas(self) -> DataFrameLike: ...
-
-class SparkConversionMixin:
-    @overload
-    def createDataFrame(
-        self, data: DataFrameLike, samplingRatio: Optional[float] = ...
-    ) -> pyspark.sql.dataframe.DataFrame: ...
-    @overload
-    def createDataFrame(
-        self,
-        data: DataFrameLike,
-        schema: Union[StructType, str],
-        verifySchema: bool = ...,
-    ) -> pyspark.sql.dataframe.DataFrame: ...
diff --git a/python/pyspark/sql/pandas/group_ops.py 
b/python/pyspark/sql/pandas/group_ops.py
index 8d4f67e..84db18f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -15,11 +15,21 @@
 # limitations under the License.
 #
 import sys
+from typing import List, Union, TYPE_CHECKING
 import warnings
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import (
+        GroupedMapPandasUserDefinedFunction,
+        PandasGroupedMapFunction,
+        PandasCogroupedMapFunction,
+    )
+    from pyspark.sql.group import GroupedData
 
 
 class PandasGroupedOpsMixin(object):
@@ -28,7 +38,7 @@ class PandasGroupedOpsMixin(object):
     can use this class.
     """
 
-    def apply(self, udf):
+    def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame:
         """
         It is an alias of :meth:`pyspark.sql.GroupedData.applyInPandas`; 
however, it takes a
         :meth:`pyspark.sql.functions.pandas_udf` whereas
@@ -73,8 +83,10 @@ class PandasGroupedOpsMixin(object):
         pyspark.sql.functions.pandas_udf
         """
         # Columns are special because hasattr always return True
-        if isinstance(udf, Column) or not hasattr(udf, 'func') \
-                or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        if isinstance(udf, Column) or not hasattr(udf, 'func') or (
+            udf.evalType  # type: ignore[attr-defined]
+            != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+        ):
             raise ValueError("Invalid udf: the udf argument must be a 
pandas_udf of type "
                              "GROUPED_MAP.")
 
@@ -83,9 +95,11 @@ class PandasGroupedOpsMixin(object):
             "API. This API will be deprecated in the future releases. See 
SPARK-28264 for "
             "more details.", UserWarning)
 
-        return self.applyInPandas(udf.func, schema=udf.returnType)
+        return self.applyInPandas(udf.func, schema=udf.returnType)  # type: 
ignore[attr-defined]
 
-    def applyInPandas(self, func, schema):
+    def applyInPandas(
+        self, func: "PandasGroupedMapFunction", schema: Union[StructType, str]
+    ) -> DataFrame:
         """
         Maps each group of the current :class:`DataFrame` using a pandas udf 
and returns the result
         as a `DataFrame`.
@@ -197,12 +211,12 @@ class PandasGroupedOpsMixin(object):
 
         udf = pandas_udf(
             func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
-        df = self._df
+        df = self._df  # type: ignore[attr-defined]
         udf_column = udf(*[df[col] for col in df.columns])
-        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
+        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())  # type: 
ignore[attr-defined]
         return DataFrame(jdf, self.sql_ctx)
 
-    def cogroup(self, other):
+    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         """
         Cogroups this group with another group so that we can run cogrouped 
operations.
 
@@ -229,12 +243,14 @@ class PandasCogroupedOps(object):
     This API is experimental.
     """
 
-    def __init__(self, gd1, gd2):
+    def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
         self._gd1 = gd1
         self._gd2 = gd2
         self.sql_ctx = gd1.sql_ctx
 
-    def applyInPandas(self, func, schema):
+    def applyInPandas(
+        self, func: "PandasCogroupedMapFunction", schema: Union[StructType, 
str]
+    ) -> DataFrame:
         """
         Applies a function to each cogroup using pandas and returns the result
         as a `DataFrame`.
@@ -329,16 +345,18 @@ class PandasCogroupedOps(object):
             func, returnType=schema, 
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
         all_cols = self._extract_cols(self._gd1) + 
self._extract_cols(self._gd2)
         udf_column = udf(*all_cols)
-        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, 
udf_column._jc.expr())
+        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(  # type: 
ignore[attr-defined]
+            self._gd2._jgd, udf_column._jc.expr()  # type: ignore[attr-defined]
+        )
         return DataFrame(jdf, self.sql_ctx)
 
     @staticmethod
-    def _extract_cols(gd):
-        df = gd._df
+    def _extract_cols(gd: "GroupedData") -> List[Column]:
+        df = gd._df  # type: ignore[attr-defined]
         return [df[col] for col in df.columns]
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.pandas.group_ops
diff --git a/python/pyspark/sql/pandas/group_ops.pyi 
b/python/pyspark/sql/pandas/group_ops.pyi
deleted file mode 100644
index 2c543e0..0000000
--- a/python/pyspark/sql/pandas/group_ops.pyi
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Union
-
-from pyspark.sql.pandas._typing import (
-    GroupedMapPandasUserDefinedFunction,
-    PandasGroupedMapFunction,
-    PandasCogroupedMapFunction,
-)
-
-from pyspark import since as since  # noqa: F401
-from pyspark.rdd import PythonEvalType as PythonEvalType  # noqa: F401
-from pyspark.sql.column import Column as Column  # noqa: F401
-from pyspark.sql.context import SQLContext
-import pyspark.sql.group
-from pyspark.sql.dataframe import DataFrame as DataFrame
-from pyspark.sql.types import StructType
-
-class PandasGroupedOpsMixin:
-    def cogroup(self, other: pyspark.sql.group.GroupedData) -> 
PandasCogroupedOps: ...
-    def apply(self, udf: GroupedMapPandasUserDefinedFunction) -> DataFrame: ...
-    def applyInPandas(
-        self, func: PandasGroupedMapFunction, schema: Union[StructType, str]
-    ) -> DataFrame: ...
-
-class PandasCogroupedOps:
-    sql_ctx: SQLContext
-    def __init__(
-        self, gd1: pyspark.sql.group.GroupedData, gd2: 
pyspark.sql.group.GroupedData
-    ) -> None: ...
-    def applyInPandas(
-        self, func: PandasCogroupedMapFunction, schema: Union[StructType, str]
-    ) -> DataFrame: ...
diff --git a/python/pyspark/sql/pandas/map_ops.py 
b/python/pyspark/sql/pandas/map_ops.py
index 63fe371..21f81e3 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -15,8 +15,14 @@
 # limitations under the License.
 #
 import sys
+from typing import Union, TYPE_CHECKING
 
 from pyspark.rdd import PythonEvalType
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.pandas._typing import PandasMapIterFunction
 
 
 class PandasMapOpsMixin(object):
@@ -25,7 +31,9 @@ class PandasMapOpsMixin(object):
     can use this class.
     """
 
-    def mapInPandas(self, func, schema):
+    def mapInPandas(
+        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
         """
         Maps an iterator of batches in the current :class:`DataFrame` using a 
Python native
         function that takes and outputs a pandas DataFrame, and returns the 
result as a
@@ -79,11 +87,11 @@ class PandasMapOpsMixin(object):
         udf = pandas_udf(
             func, returnType=schema, 
functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInPandas(udf_column._jc.expr())
+        jdf = self._jdf.mapInPandas(udf_column._jc.expr())  # type: 
ignore[operator]
         return DataFrame(jdf, self.sql_ctx)
 
 
-def _test():
+def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.pandas.map_ops
diff --git a/python/pyspark/sql/pandas/map_ops.pyi 
b/python/pyspark/sql/pandas/map_ops.pyi
deleted file mode 100644
index cab8852..0000000
--- a/python/pyspark/sql/pandas/map_ops.pyi
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Union
-
-from pyspark.sql.pandas._typing import PandasMapIterFunction
-from pyspark import since as since  # noqa: F401
-from pyspark.rdd import PythonEvalType as PythonEvalType  # noqa: F401
-from pyspark.sql.types import StructType
-import pyspark.sql.dataframe
-
-class PandasMapOpsMixin:
-    def mapInPandas(
-        self, udf: PandasMapIterFunction, schema: Union[StructType, str]
-    ) -> pyspark.sql.dataframe.DataFrame: ...
diff --git a/python/pyspark/sql/pandas/typehints.py 
b/python/pyspark/sql/pandas/typehints.py
index e696f67..1990dd2 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -14,10 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from inspect import Signature
+from typing import Any, Callable, Optional, Union, TYPE_CHECKING
+
 from pyspark.sql.pandas.utils import require_minimum_pandas_version
 
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import (
+        PandasScalarUDFType, PandasScalarIterUDFType, PandasGroupedAggUDFType
+    )
+
 
-def infer_eval_type(sig):
+def infer_eval_type(
+    sig: Signature
+) -> Union["PandasScalarUDFType", "PandasScalarIterUDFType", 
"PandasGroupedAggUDFType"]:
     """
     Infers the evaluation type in :class:`pyspark.rdd.PythonEvalType` from
     :class:`inspect.Signature` instance.
@@ -117,7 +127,9 @@ def infer_eval_type(sig):
         raise NotImplementedError("Unsupported signature: %s." % sig)
 
 
-def check_tuple_annotation(annotation, parameter_check_func=None):
+def check_tuple_annotation(
+    annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = 
None
+) -> bool:
     # Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
     # Check if the name is Tuple first. After that, check the generic types.
     name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
@@ -125,13 +137,17 @@ def check_tuple_annotation(annotation, 
parameter_check_func=None):
         parameter_check_func is None or all(map(parameter_check_func, 
annotation.__args__)))
 
 
-def check_iterator_annotation(annotation, parameter_check_func=None):
+def check_iterator_annotation(
+    annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = 
None
+) -> bool:
     name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
     return name == "Iterator" and (
         parameter_check_func is None or all(map(parameter_check_func, 
annotation.__args__)))
 
 
-def check_union_annotation(annotation, parameter_check_func=None):
+def check_union_annotation(
+    annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = 
None
+) -> bool:
     import typing
 
     # Note that we cannot rely on '__origin__' in other type hints as it has 
changed from version
diff --git a/python/pyspark/sql/pandas/types.py 
b/python/pyspark/sql/pandas/types.py
index ceb71a3..44253bf 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -19,13 +19,19 @@
 Type-specific codes between pandas and PyArrow. Also contains some utils to 
correct
 pandas instances during the type conversion.
 """
+from typing import Optional, TYPE_CHECKING
 
 from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType, 
LongType, \
     FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType, 
TimestampType, \
-    TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType
+    TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType, 
DataType
+
+if TYPE_CHECKING:
+    import pyarrow as pa
+
+    from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike
 
 
-def to_arrow_type(dt):
+def to_arrow_type(dt: DataType) -> "pa.DataType":
     """ Convert Spark data type to pyarrow type
     """
     from distutils.version import LooseVersion
@@ -81,7 +87,7 @@ def to_arrow_type(dt):
     return arrow_type
 
 
-def to_arrow_schema(schema):
+def to_arrow_schema(schema: StructType) -> "pa.Schema":
     """ Convert a schema from Spark to Arrow
     """
     import pyarrow as pa
@@ -90,14 +96,14 @@ def to_arrow_schema(schema):
     return pa.schema(fields)
 
 
-def from_arrow_type(at, prefer_timestamp_ntz=False):
+def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> 
DataType:
     """ Convert pyarrow type to Spark data type.
     """
     from distutils.version import LooseVersion
     import pyarrow as pa
     import pyarrow.types as types
     if types.is_boolean(at):
-        spark_type = BooleanType()
+        spark_type = BooleanType()  # type: DataType
     elif types.is_int8(at):
         spark_type = ByteType()
     elif types.is_int16(at):
@@ -147,7 +153,7 @@ def from_arrow_type(at, prefer_timestamp_ntz=False):
     return spark_type
 
 
-def from_arrow_schema(arrow_schema):
+def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
     """ Convert schema from Arrow to Spark.
     """
     return StructType(
@@ -155,7 +161,7 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
-def _get_local_timezone():
+def _get_local_timezone() -> str:
     """ Get local timezone using pytz with environment variable, or dateutil.
 
     If there is a 'TZ' environment variable, pass it to pandas to use pytz and use it as timezone
@@ -170,7 +176,7 @@ def _get_local_timezone():
     return os.environ.get('TZ', 'dateutil/:')
 
 
-def _check_series_localize_timestamps(s, timezone):
+def _check_series_localize_timestamps(s: "PandasSeriesLike", timezone: str) -> "PandasSeriesLike":
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
 
@@ -200,7 +206,9 @@ def _check_series_localize_timestamps(s, timezone):
         return s
 
 
-def _check_series_convert_timestamps_internal(s, timezone):
+def _check_series_convert_timestamps_internal(
+    s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
     """
     Convert a tz-naive timestamp in the specified timezone or local timezone to UTC normalized for
     Spark internal storage
@@ -260,7 +268,9 @@ def _check_series_convert_timestamps_internal(s, timezone):
         return s
 
 
-def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
+def _check_series_convert_timestamps_localize(
+    s: "PandasSeriesLike", from_timezone: Optional[str], to_timezone: Optional[str]
+) -> "PandasSeriesLike":
     """
     Convert timestamp to timezone-naive in the specified timezone or local timezone
 
@@ -296,7 +306,9 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
         return s
 
 
-def _check_series_convert_timestamps_local_tz(s, timezone):
+def _check_series_convert_timestamps_local_tz(
+    s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
     """
     Convert timestamp to timezone-naive in the specified timezone or local timezone
 
@@ -314,7 +326,9 @@ def _check_series_convert_timestamps_local_tz(s, timezone):
     return _check_series_convert_timestamps_localize(s, None, timezone)
 
 
-def _check_series_convert_timestamps_tz_local(s, timezone):
+def _check_series_convert_timestamps_tz_local(
+    s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
     """
     Convert timestamp to timezone-naive in the specified timezone or local timezone
 
@@ -332,7 +346,7 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
     return _check_series_convert_timestamps_localize(s, timezone, None)
 
 
-def _convert_map_items_to_dict(s):
+def _convert_map_items_to_dict(s: "PandasSeriesLike") -> "PandasSeriesLike":
     """
     Convert a series with items as list of (key, value), as made from an Arrow column of map type,
     to dict for compatibility with non-arrow MapType columns.
@@ -342,7 +356,7 @@ def _convert_map_items_to_dict(s):
     return s.apply(lambda m: None if m is None else {k: v for k, v in m})
 
 
-def _convert_dict_to_map_items(s):
+def _convert_dict_to_map_items(s: "PandasSeriesLike") -> "PandasSeriesLike":
     """
     Convert a series of dictionaries to list of (key, value) pairs to match expected data
     for Arrow column of map type.
diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py
index b22603a..4144929 100644
--- a/python/pyspark/sql/pandas/utils.py
+++ b/python/pyspark/sql/pandas/utils.py
@@ -16,7 +16,7 @@
 #
 
 
-def require_minimum_pandas_version():
+def require_minimum_pandas_version() -> None:
     """ Raise ImportError if minimum version of Pandas is not installed
     """
     # TODO(HyukjinKwon): Relocate and deduplicate the version specification.
@@ -37,7 +37,7 @@ def require_minimum_pandas_version():
                           "your version was %s." % (minimum_pandas_version, pandas.__version__))
 
 
-def require_minimum_pyarrow_version():
+def require_minimum_pyarrow_version() -> None:
     """ Raise ImportError if minimum version of pyarrow is not installed
     """
     # TODO(HyukjinKwon): Relocate and deduplicate the version specification.

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to