This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a9b4c27 [SPARK-36846][PYTHON] Inline most of type hint files under
pyspark/sql/pandas folder
a9b4c27 is described below
commit a9b4c27f12e308b3674e1d6faf6ff998c1aed146
Author: Takuya UESHIN <[email protected]>
AuthorDate: Wed Sep 29 09:25:18 2021 +0900
[SPARK-36846][PYTHON] Inline most of type hint files under
pyspark/sql/pandas folder
### What changes were proposed in this pull request?
Inlines type hint files under `pyspark/sql/pandas` folder, except for
`pyspark/sql/pandas/functions.pyi` and files under `pyspark/sql/pandas/_typing`.
- Since `pyspark/sql/pandas/functions.pyi` contains a lot of overloads, we
should revisit and manage it separately.
- We can't inline files under `pyspark/sql/pandas/_typing` because it
includes new syntax for type hints.
### Why are the changes needed?
Currently there are type hint stub files (`*.pyi`) that show the expected
types for functions, but by inlining the type hints we can also take advantage
of static type checking within the function bodies themselves.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
Closes #34101 from ueshin/issues/SPARK-36846/inline_typehints.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../pyspark/sql/pandas/_typing/protocols/frame.pyi | 5 +
python/pyspark/sql/pandas/conversion.py | 120 +++++++++++++++------
python/pyspark/sql/pandas/conversion.pyi | 59 ----------
python/pyspark/sql/pandas/group_ops.py | 46 +++++---
python/pyspark/sql/pandas/group_ops.pyi | 49 ---------
python/pyspark/sql/pandas/map_ops.py | 14 ++-
python/pyspark/sql/pandas/map_ops.pyi | 30 ------
python/pyspark/sql/pandas/typehints.py | 24 ++++-
python/pyspark/sql/pandas/types.py | 42 +++++---
python/pyspark/sql/pandas/utils.py | 4 +-
10 files changed, 186 insertions(+), 207 deletions(-)
diff --git a/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
index e219c1c..6f450df 100644
--- a/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
+++ b/python/pyspark/sql/pandas/_typing/protocols/frame.pyi
@@ -35,6 +35,8 @@ Axis = Any
Level = Any
class DataFrameLike(Protocol):
+ columns: Axes
+ dtypes: List[Any]
def __init__(
self,
data: Any = ...,
@@ -422,7 +424,10 @@ class DataFrameLike(Protocol):
self, freq: Any = ..., axis: Any = ..., copy: Any = ...
) -> DataFrameLike: ...
def isin(self, values: Any) -> DataFrameLike: ...
+ def copy(self) -> DataFrameLike: ...
plot: Any = ...
hist: Any = ...
boxplot: Any = ...
sparse: Any = ...
+ loc: Any = ...
+ iloc: Any = ...
diff --git a/python/pyspark/sql/pandas/conversion.py
b/python/pyspark/sql/pandas/conversion.py
index 6761217..354d3a9 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -17,8 +17,9 @@
import sys
import warnings
from collections import Counter
+from typing import List, Optional, Type, Union, no_type_check, overload,
TYPE_CHECKING
-from pyspark.rdd import _load_from_socket
+from pyspark.rdd import _load_from_socket # type: ignore[attr-defined]
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType,
FloatType, \
@@ -26,6 +27,13 @@ from pyspark.sql.types import ByteType, ShortType,
IntegerType, LongType, FloatT
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync
+if TYPE_CHECKING:
+ import numpy as np
+ import pyarrow as pa
+
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql import DataFrame
+
class PandasConversionMixin(object):
"""
@@ -33,7 +41,7 @@ class PandasConversionMixin(object):
can use this class.
"""
- def toPandas(self):
+ def toPandas(self) -> "PandasDataFrameLike":
"""
Returns the contents of this :class:`DataFrame` as Pandas
``pandas.DataFrame``.
@@ -65,9 +73,9 @@ class PandasConversionMixin(object):
import numpy as np
import pandas as pd
- timezone = self.sql_ctx._conf.sessionLocalTimeZone()
+ timezone = self.sql_ctx._conf.sessionLocalTimeZone() # type:
ignore[attr-defined]
- if self.sql_ctx._conf.arrowPySparkEnabled():
+ if self.sql_ctx._conf.arrowPySparkEnabled(): # type:
ignore[attr-defined]
use_arrow = True
try:
from pyspark.sql.pandas.types import to_arrow_schema
@@ -77,7 +85,7 @@ class PandasConversionMixin(object):
to_arrow_schema(self.schema)
except Exception as e:
- if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
+ if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): # type:
ignore[attr-defined]
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to
true; however, "
@@ -106,7 +114,10 @@ class PandasConversionMixin(object):
import pyarrow
# Rename columns to avoid duplicated column names.
tmp_column_names = ['col_{}'.format(i) for i in
range(len(self.columns))]
- self_destruct =
self.sql_ctx._conf.arrowPySparkSelfDestructEnabled()
+ self_destruct = (
+ self.sql_ctx._conf # type: ignore[attr-defined]
+ .arrowPySparkSelfDestructEnabled()
+ )
batches = self.toDF(*tmp_column_names)._collect_as_arrow(
split_batches=self_destruct)
if len(batches) > 0:
@@ -158,7 +169,7 @@ class PandasConversionMixin(object):
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
column_counter = Counter(self.columns)
- dtype = [None] * len(self.schema)
+ dtype = [None] * len(self.schema) # type: List[Optional[Type]]
for fieldIdx, field in enumerate(self.schema):
# For duplicate column name, we use `iloc` to access it.
if column_counter[field.name] > 1:
@@ -179,7 +190,7 @@ class PandasConversionMixin(object):
if isinstance(field.dataType, IntegralType) and
pandas_col.isnull().any():
dtype[fieldIdx] = np.float64
if isinstance(field.dataType, BooleanType) and
pandas_col.isnull().any():
- dtype[fieldIdx] = np.object
+ dtype[fieldIdx] = np.object # type: ignore[attr-defined]
df = pd.DataFrame()
for index, t in enumerate(dtype):
@@ -216,7 +227,7 @@ class PandasConversionMixin(object):
return pdf
@staticmethod
- def _to_corrected_pandas_type(dt):
+ def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]:
"""
When converting Spark SQL records to Pandas :class:`DataFrame`, the
inferred data type
may be wrong. This method gets the corrected data type for Pandas if
that type may be
@@ -236,7 +247,7 @@ class PandasConversionMixin(object):
elif type(dt) == DoubleType:
return np.float64
elif type(dt) == BooleanType:
- return np.bool
+ return np.bool # type: ignore[attr-defined]
elif type(dt) == TimestampType:
return np.datetime64
elif type(dt) == TimestampNTZType:
@@ -244,7 +255,7 @@ class PandasConversionMixin(object):
else:
return None
- def _collect_as_arrow(self, split_batches=False):
+ def _collect_as_arrow(self, split_batches: bool = False) ->
List["pa.RecordBatch"]:
"""
Returns all records as a list of ArrowRecordBatches, pyarrow must be
installed
and available on driver and worker Python environments.
@@ -260,7 +271,9 @@ class PandasConversionMixin(object):
assert isinstance(self, DataFrame)
with SCCallSiteSync(self._sc):
- port, auth_secret, jsocket_auth_server =
self._jdf.collectAsArrowToPython()
+ port, auth_secret, jsocket_auth_server = (
+ self._jdf.collectAsArrowToPython() # type: ignore[operator]
+ )
# Collect list of un-ordered batches where last element is a list of
correct order indices
try:
@@ -301,7 +314,29 @@ class SparkConversionMixin(object):
Min-in for the conversion from pandas to Spark. Currently, only
:class:`SparkSession`
can use this class.
"""
- def createDataFrame(self, data, schema=None, samplingRatio=None,
verifySchema=True):
+
+ @overload
+ def createDataFrame(
+ self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+ ) -> "DataFrame":
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: "PandasDataFrameLike",
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> "DataFrame":
+ ...
+
+ def createDataFrame( # type: ignore[misc]
+ self,
+ data: "PandasDataFrameLike",
+ schema: Optional[Union[StructType, List[str]]] = None,
+ samplingRatio: Optional[float] = None,
+ verifySchema: bool = True
+ ) -> "DataFrame":
from pyspark.sql import SparkSession
assert isinstance(self, SparkSession)
@@ -309,19 +344,20 @@ class SparkConversionMixin(object):
from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()
- timezone = self._wrapped._conf.sessionLocalTimeZone()
+ timezone = self._wrapped._conf.sessionLocalTimeZone() # type:
ignore[attr-defined]
# If no schema supplied by user then get the names of columns only
if schema is None:
- schema = [str(x) if not isinstance(x, str) else
- (x.encode('utf-8') if not isinstance(x, str) else x)
- for x in data.columns]
+ schema = [str(x) if not isinstance(x, str) else x for x in
data.columns]
- if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
+ if (
+ self._wrapped._conf.arrowPySparkEnabled() # type:
ignore[attr-defined]
+ and len(data) > 0
+ ):
try:
return self._create_from_pandas_with_arrow(data, schema,
timezone)
except Exception as e:
- if self._wrapped._conf.arrowPySparkFallbackEnabled():
+ if self._wrapped._conf.arrowPySparkFallbackEnabled(): # type:
ignore[attr-defined]
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to
true; however, "
@@ -339,10 +375,17 @@ class SparkConversionMixin(object):
"has been set to false.\n %s" % str(e))
warnings.warn(msg)
raise
- data = self._convert_from_pandas(data, schema, timezone)
- return self._create_dataframe(data, schema, samplingRatio,
verifySchema)
-
- def _convert_from_pandas(self, pdf, schema, timezone):
+ converted_data = self._convert_from_pandas(data, schema, timezone)
+ return self._create_dataframe( # type: ignore[attr-defined]
+ converted_data, schema, samplingRatio, verifySchema
+ )
+
+ def _convert_from_pandas(
+ self,
+ pdf: "PandasDataFrameLike",
+ schema: Union[StructType, str, List[str]],
+ timezone: str
+ ) -> List:
"""
Convert a pandas.DataFrame to list of records that can be used to
make a DataFrame
@@ -398,7 +441,7 @@ class SparkConversionMixin(object):
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records]
- def _get_numpy_record_dtype(self, rec):
+ def _get_numpy_record_dtype(self, rec: "np.recarray") ->
Optional["np.dtype"]:
"""
Used when converting a pandas.DataFrame to Spark using to_records(),
this will correct
the dtypes of fields in a record so they can be properly loaded into
Spark.
@@ -429,7 +472,12 @@ class SparkConversionMixin(object):
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None
- def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
+ def _create_from_pandas_with_arrow(
+ self,
+ pdf: "PandasDataFrameLike",
+ schema: Union[StructType, List[str]],
+ timezone: str
+ ) -> "DataFrame":
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into
partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is
passed in, the
@@ -483,27 +531,35 @@ class SparkConversionMixin(object):
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(),
arrow_types)]
for pdf_slice in pdf_slices]
- jsqlContext = self._wrapped._jsqlContext
+ jsqlContext = self._wrapped._jsqlContext # type: ignore[attr-defined]
- safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+ safecheck = self._wrapped._conf.arrowSafeTypeConversion() # type:
ignore[attr-defined]
col_by_name = True # col by name only applies to StructType columns,
can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
+ @no_type_check
def reader_func(temp_filename):
return
self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
+ @no_type_check
def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)
# Create Spark DataFrame from Arrow stream file, using one batch per
partition
- jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func,
create_RDD_server)
- jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(),
jsqlContext)
- df = DataFrame(jdf, self._wrapped)
- df._schema = schema
+ jrdd = (
+ self._sc # type: ignore[attr-defined]
+ ._serialize_to_jvm(arrow_data, ser, reader_func,
create_RDD_server)
+ )
+ jdf = (
+ self._jvm # type: ignore[attr-defined]
+ .PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
+ )
+ df = DataFrame(jdf, self._wrapped) # type: ignore[attr-defined]
+ df._schema = schema # type: ignore[attr-defined]
return df
-def _test():
+def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.conversion
diff --git a/python/pyspark/sql/pandas/conversion.pyi
b/python/pyspark/sql/pandas/conversion.pyi
deleted file mode 100644
index 87637722..0000000
--- a/python/pyspark/sql/pandas/conversion.pyi
+++ /dev/null
@@ -1,59 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import overload
-from typing import Optional, Union
-
-from pyspark.sql.pandas._typing import DataFrameLike
-from pyspark import since as since # noqa: F401
-from pyspark.rdd import RDD # noqa: F401
-import pyspark.sql.dataframe
-from pyspark.sql.pandas.serializers import ( # noqa: F401
- ArrowCollectSerializer as ArrowCollectSerializer,
-)
-from pyspark.sql.types import ( # noqa: F401
- BooleanType as BooleanType,
- ByteType as ByteType,
- DataType as DataType,
- DoubleType as DoubleType,
- FloatType as FloatType,
- IntegerType as IntegerType,
- IntegralType as IntegralType,
- LongType as LongType,
- ShortType as ShortType,
- StructType as StructType,
- TimestampType as TimestampType,
- TimestampNTZType as TimestampNTZType,
-)
-from pyspark.traceback_utils import SCCallSiteSync as SCCallSiteSync # noqa:
F401
-
-class PandasConversionMixin:
- def toPandas(self) -> DataFrameLike: ...
-
-class SparkConversionMixin:
- @overload
- def createDataFrame(
- self, data: DataFrameLike, samplingRatio: Optional[float] = ...
- ) -> pyspark.sql.dataframe.DataFrame: ...
- @overload
- def createDataFrame(
- self,
- data: DataFrameLike,
- schema: Union[StructType, str],
- verifySchema: bool = ...,
- ) -> pyspark.sql.dataframe.DataFrame: ...
diff --git a/python/pyspark/sql/pandas/group_ops.py
b/python/pyspark/sql/pandas/group_ops.py
index 8d4f67e..84db18f 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -15,11 +15,21 @@
# limitations under the License.
#
import sys
+from typing import List, Union, TYPE_CHECKING
import warnings
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+ from pyspark.sql.pandas._typing import (
+ GroupedMapPandasUserDefinedFunction,
+ PandasGroupedMapFunction,
+ PandasCogroupedMapFunction,
+ )
+ from pyspark.sql.group import GroupedData
class PandasGroupedOpsMixin(object):
@@ -28,7 +38,7 @@ class PandasGroupedOpsMixin(object):
can use this class.
"""
- def apply(self, udf):
+ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> DataFrame:
"""
It is an alias of :meth:`pyspark.sql.GroupedData.applyInPandas`;
however, it takes a
:meth:`pyspark.sql.functions.pandas_udf` whereas
@@ -73,8 +83,10 @@ class PandasGroupedOpsMixin(object):
pyspark.sql.functions.pandas_udf
"""
# Columns are special because hasattr always return True
- if isinstance(udf, Column) or not hasattr(udf, 'func') \
- or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ if isinstance(udf, Column) or not hasattr(udf, 'func') or (
+ udf.evalType # type: ignore[attr-defined]
+ != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+ ):
raise ValueError("Invalid udf: the udf argument must be a
pandas_udf of type "
"GROUPED_MAP.")
@@ -83,9 +95,11 @@ class PandasGroupedOpsMixin(object):
"API. This API will be deprecated in the future releases. See
SPARK-28264 for "
"more details.", UserWarning)
- return self.applyInPandas(udf.func, schema=udf.returnType)
+ return self.applyInPandas(udf.func, schema=udf.returnType) # type:
ignore[attr-defined]
- def applyInPandas(self, func, schema):
+ def applyInPandas(
+ self, func: "PandasGroupedMapFunction", schema: Union[StructType, str]
+ ) -> DataFrame:
"""
Maps each group of the current :class:`DataFrame` using a pandas udf
and returns the result
as a `DataFrame`.
@@ -197,12 +211,12 @@ class PandasGroupedOpsMixin(object):
udf = pandas_udf(
func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
- df = self._df
+ df = self._df # type: ignore[attr-defined]
udf_column = udf(*[df[col] for col in df.columns])
- jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
+ jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) # type:
ignore[attr-defined]
return DataFrame(jdf, self.sql_ctx)
- def cogroup(self, other):
+ def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
"""
Cogroups this group with another group so that we can run cogrouped
operations.
@@ -229,12 +243,14 @@ class PandasCogroupedOps(object):
This API is experimental.
"""
- def __init__(self, gd1, gd2):
+ def __init__(self, gd1: "GroupedData", gd2: "GroupedData"):
self._gd1 = gd1
self._gd2 = gd2
self.sql_ctx = gd1.sql_ctx
- def applyInPandas(self, func, schema):
+ def applyInPandas(
+ self, func: "PandasCogroupedMapFunction", schema: Union[StructType,
str]
+ ) -> DataFrame:
"""
Applies a function to each cogroup using pandas and returns the result
as a `DataFrame`.
@@ -329,16 +345,18 @@ class PandasCogroupedOps(object):
func, returnType=schema,
functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
all_cols = self._extract_cols(self._gd1) +
self._extract_cols(self._gd2)
udf_column = udf(*all_cols)
- jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd,
udf_column._jc.expr())
+ jdf = self._gd1._jgd.flatMapCoGroupsInPandas( # type:
ignore[attr-defined]
+ self._gd2._jgd, udf_column._jc.expr() # type: ignore[attr-defined]
+ )
return DataFrame(jdf, self.sql_ctx)
@staticmethod
- def _extract_cols(gd):
- df = gd._df
+ def _extract_cols(gd: "GroupedData") -> List[Column]:
+ df = gd._df # type: ignore[attr-defined]
return [df[col] for col in df.columns]
-def _test():
+def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.group_ops
diff --git a/python/pyspark/sql/pandas/group_ops.pyi
b/python/pyspark/sql/pandas/group_ops.pyi
deleted file mode 100644
index 2c543e0..0000000
--- a/python/pyspark/sql/pandas/group_ops.pyi
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Union
-
-from pyspark.sql.pandas._typing import (
- GroupedMapPandasUserDefinedFunction,
- PandasGroupedMapFunction,
- PandasCogroupedMapFunction,
-)
-
-from pyspark import since as since # noqa: F401
-from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401
-from pyspark.sql.column import Column as Column # noqa: F401
-from pyspark.sql.context import SQLContext
-import pyspark.sql.group
-from pyspark.sql.dataframe import DataFrame as DataFrame
-from pyspark.sql.types import StructType
-
-class PandasGroupedOpsMixin:
- def cogroup(self, other: pyspark.sql.group.GroupedData) ->
PandasCogroupedOps: ...
- def apply(self, udf: GroupedMapPandasUserDefinedFunction) -> DataFrame: ...
- def applyInPandas(
- self, func: PandasGroupedMapFunction, schema: Union[StructType, str]
- ) -> DataFrame: ...
-
-class PandasCogroupedOps:
- sql_ctx: SQLContext
- def __init__(
- self, gd1: pyspark.sql.group.GroupedData, gd2:
pyspark.sql.group.GroupedData
- ) -> None: ...
- def applyInPandas(
- self, func: PandasCogroupedMapFunction, schema: Union[StructType, str]
- ) -> DataFrame: ...
diff --git a/python/pyspark/sql/pandas/map_ops.py
b/python/pyspark/sql/pandas/map_ops.py
index 63fe371..21f81e3 100644
--- a/python/pyspark/sql/pandas/map_ops.py
+++ b/python/pyspark/sql/pandas/map_ops.py
@@ -15,8 +15,14 @@
# limitations under the License.
#
import sys
+from typing import Union, TYPE_CHECKING
from pyspark.rdd import PythonEvalType
+from pyspark.sql.types import StructType
+
+if TYPE_CHECKING:
+ from pyspark.sql.dataframe import DataFrame
+ from pyspark.sql.pandas._typing import PandasMapIterFunction
class PandasMapOpsMixin(object):
@@ -25,7 +31,9 @@ class PandasMapOpsMixin(object):
can use this class.
"""
- def mapInPandas(self, func, schema):
+ def mapInPandas(
+ self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+ ) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a
Python native
function that takes and outputs a pandas DataFrame, and returns the
result as a
@@ -79,11 +87,11 @@ class PandasMapOpsMixin(object):
udf = pandas_udf(
func, returnType=schema,
functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
udf_column = udf(*[self[col] for col in self.columns])
- jdf = self._jdf.mapInPandas(udf_column._jc.expr())
+ jdf = self._jdf.mapInPandas(udf_column._jc.expr()) # type:
ignore[operator]
return DataFrame(jdf, self.sql_ctx)
-def _test():
+def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.map_ops
diff --git a/python/pyspark/sql/pandas/map_ops.pyi
b/python/pyspark/sql/pandas/map_ops.pyi
deleted file mode 100644
index cab8852..0000000
--- a/python/pyspark/sql/pandas/map_ops.pyi
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import Union
-
-from pyspark.sql.pandas._typing import PandasMapIterFunction
-from pyspark import since as since # noqa: F401
-from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401
-from pyspark.sql.types import StructType
-import pyspark.sql.dataframe
-
-class PandasMapOpsMixin:
- def mapInPandas(
- self, udf: PandasMapIterFunction, schema: Union[StructType, str]
- ) -> pyspark.sql.dataframe.DataFrame: ...
diff --git a/python/pyspark/sql/pandas/typehints.py
b/python/pyspark/sql/pandas/typehints.py
index e696f67..1990dd2 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -14,10 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from inspect import Signature
+from typing import Any, Callable, Optional, Union, TYPE_CHECKING
+
from pyspark.sql.pandas.utils import require_minimum_pandas_version
+if TYPE_CHECKING:
+ from pyspark.sql.pandas._typing import (
+ PandasScalarUDFType, PandasScalarIterUDFType, PandasGroupedAggUDFType
+ )
+
-def infer_eval_type(sig):
+def infer_eval_type(
+ sig: Signature
+) -> Union["PandasScalarUDFType", "PandasScalarIterUDFType",
"PandasGroupedAggUDFType"]:
"""
Infers the evaluation type in :class:`pyspark.rdd.PythonEvalType` from
:class:`inspect.Signature` instance.
@@ -117,7 +127,9 @@ def infer_eval_type(sig):
raise NotImplementedError("Unsupported signature: %s." % sig)
-def check_tuple_annotation(annotation, parameter_check_func=None):
+def check_tuple_annotation(
+ annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] =
None
+) -> bool:
# Python 3.6 has `__name__`. Python 3.7 and 3.8 have `_name`.
# Check if the name is Tuple first. After that, check the generic types.
name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
@@ -125,13 +137,17 @@ def check_tuple_annotation(annotation,
parameter_check_func=None):
parameter_check_func is None or all(map(parameter_check_func,
annotation.__args__)))
-def check_iterator_annotation(annotation, parameter_check_func=None):
+def check_iterator_annotation(
+ annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] =
None
+) -> bool:
name = getattr(annotation, "_name", getattr(annotation, "__name__", None))
return name == "Iterator" and (
parameter_check_func is None or all(map(parameter_check_func,
annotation.__args__)))
-def check_union_annotation(annotation, parameter_check_func=None):
+def check_union_annotation(
+ annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] =
None
+) -> bool:
import typing
# Note that we cannot rely on '__origin__' in other type hints as it has
changed from version
diff --git a/python/pyspark/sql/pandas/types.py
b/python/pyspark/sql/pandas/types.py
index ceb71a3..44253bf 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -19,13 +19,19 @@
Type-specific codes between pandas and PyArrow. Also contains some utils to
correct
pandas instances during the type conversion.
"""
+from typing import Optional, TYPE_CHECKING
from pyspark.sql.types import BooleanType, ByteType, ShortType, IntegerType,
LongType, \
FloatType, DoubleType, DecimalType, StringType, BinaryType, DateType,
TimestampType, \
- TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType
+ TimestampNTZType, ArrayType, MapType, StructType, StructField, NullType,
DataType
+
+if TYPE_CHECKING:
+ import pyarrow as pa
+
+ from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike
-def to_arrow_type(dt):
+def to_arrow_type(dt: DataType) -> "pa.DataType":
""" Convert Spark data type to pyarrow type
"""
from distutils.version import LooseVersion
@@ -81,7 +87,7 @@ def to_arrow_type(dt):
return arrow_type
-def to_arrow_schema(schema):
+def to_arrow_schema(schema: StructType) -> "pa.Schema":
""" Convert a schema from Spark to Arrow
"""
import pyarrow as pa
@@ -90,14 +96,14 @@ def to_arrow_schema(schema):
return pa.schema(fields)
-def from_arrow_type(at, prefer_timestamp_ntz=False):
+def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) ->
DataType:
""" Convert pyarrow type to Spark data type.
"""
from distutils.version import LooseVersion
import pyarrow as pa
import pyarrow.types as types
if types.is_boolean(at):
- spark_type = BooleanType()
+ spark_type = BooleanType() # type: DataType
elif types.is_int8(at):
spark_type = ByteType()
elif types.is_int16(at):
@@ -147,7 +153,7 @@ def from_arrow_type(at, prefer_timestamp_ntz=False):
return spark_type
-def from_arrow_schema(arrow_schema):
+def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
""" Convert schema from Arrow to Spark.
"""
return StructType(
@@ -155,7 +161,7 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
-def _get_local_timezone():
+def _get_local_timezone() -> str:
""" Get local timezone using pytz with environment variable, or dateutil.
If there is a 'TZ' environment variable, pass it to pandas to use pytz and
use it as timezone
@@ -170,7 +176,7 @@ def _get_local_timezone():
return os.environ.get('TZ', 'dateutil/:')
-def _check_series_localize_timestamps(s, timezone):
+def _check_series_localize_timestamps(s: "PandasSeriesLike", timezone: str) ->
"PandasSeriesLike":
"""
Convert timezone aware timestamps to timezone-naive in the specified
timezone or local timezone.
@@ -200,7 +206,9 @@ def _check_series_localize_timestamps(s, timezone):
return s
-def _check_series_convert_timestamps_internal(s, timezone):
+def _check_series_convert_timestamps_internal(
+ s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
"""
Convert a tz-naive timestamp in the specified timezone or local timezone
to UTC normalized for
Spark internal storage
@@ -260,7 +268,9 @@ def _check_series_convert_timestamps_internal(s, timezone):
return s
-def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone):
+def _check_series_convert_timestamps_localize(
+ s: "PandasSeriesLike", from_timezone: Optional[str], to_timezone:
Optional[str]
+) -> "PandasSeriesLike":
"""
Convert timestamp to timezone-naive in the specified timezone or local
timezone
@@ -296,7 +306,9 @@ def _check_series_convert_timestamps_localize(s,
from_timezone, to_timezone):
return s
-def _check_series_convert_timestamps_local_tz(s, timezone):
+def _check_series_convert_timestamps_local_tz(
+ s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
"""
Convert timestamp to timezone-naive in the specified timezone or local
timezone
@@ -314,7 +326,9 @@ def _check_series_convert_timestamps_local_tz(s, timezone):
return _check_series_convert_timestamps_localize(s, None, timezone)
-def _check_series_convert_timestamps_tz_local(s, timezone):
+def _check_series_convert_timestamps_tz_local(
+ s: "PandasSeriesLike", timezone: str
+) -> "PandasSeriesLike":
"""
Convert timestamp to timezone-naive in the specified timezone or local
timezone
@@ -332,7 +346,7 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
return _check_series_convert_timestamps_localize(s, timezone, None)
-def _convert_map_items_to_dict(s):
+def _convert_map_items_to_dict(s: "PandasSeriesLike") -> "PandasSeriesLike":
"""
Convert a series with items as list of (key, value), as made from an Arrow
column of map type,
to dict for compatibility with non-arrow MapType columns.
@@ -342,7 +356,7 @@ def _convert_map_items_to_dict(s):
return s.apply(lambda m: None if m is None else {k: v for k, v in m})
-def _convert_dict_to_map_items(s):
+def _convert_dict_to_map_items(s: "PandasSeriesLike") -> "PandasSeriesLike":
"""
Convert a series of dictionaries to list of (key, value) pairs to match
expected data
for Arrow column of map type.
diff --git a/python/pyspark/sql/pandas/utils.py
b/python/pyspark/sql/pandas/utils.py
index b22603a..4144929 100644
--- a/python/pyspark/sql/pandas/utils.py
+++ b/python/pyspark/sql/pandas/utils.py
@@ -16,7 +16,7 @@
#
-def require_minimum_pandas_version():
+def require_minimum_pandas_version() -> None:
""" Raise ImportError if minimum version of Pandas is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
@@ -37,7 +37,7 @@ def require_minimum_pandas_version():
"your version was %s." % (minimum_pandas_version,
pandas.__version__))
-def require_minimum_pyarrow_version():
+def require_minimum_pyarrow_version() -> None:
""" Raise ImportError if minimum version of pyarrow is not installed
"""
# TODO(HyukjinKwon): Relocate and deduplicate the version specification.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]