ueshin commented on a change in pull request #34136:
URL: https://github.com/apache/spark/pull/34136#discussion_r718879032
##########
File path: python/pyspark/sql/session.py
##########
@@ -19,24 +19,46 @@
import warnings
from functools import reduce
from threading import RLock
+from types import TracebackType
+from typing import (
+ Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+ cast, no_type_check, overload, TYPE_CHECKING
+)
-from pyspark import since
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import SparkConf, SparkContext, since
from pyspark.rdd import RDD
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import DataType, StructType, \
- _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+from pyspark.sql.types import ( # type: ignore[attr-defined]
+ AtomicType, DataType, StructType,
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
 _parse_datatype_string
+)
from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
+if TYPE_CHECKING:
+ from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+ from pyspark.sql.catalog import Catalog
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.udf import UDFRegistration
+
+
__all__ = ["SparkSession"]
-def _monkey_patch_RDD(sparkSession):
- def toDF(self, schema=None, sampleRatio=None):
+def _monkey_patch_RDD(sparkSession: "SparkSession") -> None:
+
+ def toDF(
+ self: "RDD[RowLike]",
+ schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
Review comment:
Ah, cool. I missed that the annotations are in `rdd.pyi`.
I guess we can just mark it `@no_type_check` here. Thanks!
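
For example, a minimal sketch of what I have in mind (the body shown is only illustrative; the annotations in `rdd.pyi` stay the source of truth for callers):

```python
from typing import no_type_check

# mypy skips the body of a function decorated with @no_type_check,
# so we avoid re-annotating the monkey-patched toDF here.
@no_type_check
def toDF(self, schema=None, sampleRatio=None):
    return sparkSession.createDataFrame(self, schema, sampleRatio)
```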
##########
File path: python/pyspark/sql/session.py
##########
@@ -19,24 +19,46 @@
import warnings
from functools import reduce
from threading import RLock
+from types import TracebackType
+from typing import (
+ Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+ cast, no_type_check, overload, TYPE_CHECKING
+)
-from pyspark import since
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import SparkConf, SparkContext, since
from pyspark.rdd import RDD
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import DataType, StructType, \
- _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+from pyspark.sql.types import ( # type: ignore[attr-defined]
+ AtomicType, DataType, StructType,
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
 _parse_datatype_string
+)
from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
+if TYPE_CHECKING:
+ from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+ from pyspark.sql.catalog import Catalog
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.udf import UDFRegistration
+
+
__all__ = ["SparkSession"]
-def _monkey_patch_RDD(sparkSession):
- def toDF(self, schema=None, sampleRatio=None):
+def _monkey_patch_RDD(sparkSession: "SparkSession") -> None:
+
+ def toDF(
+ self: "RDD[RowLike]",
+ schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
Review comment:
Ah, cool. I missed that the annotations are in `rdd.pyi`.
I guess we can just mark it `@no_type_check` here for now. Thanks!
##########
File path: python/pyspark/sql/session.py
##########
@@ -19,24 +19,46 @@
import warnings
from functools import reduce
from threading import RLock
+from types import TracebackType
+from typing import (
+ Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+ cast, no_type_check, overload, TYPE_CHECKING
+)
-from pyspark import since
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import SparkConf, SparkContext, since
from pyspark.rdd import RDD
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import DataType, StructType, \
- _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+from pyspark.sql.types import ( # type: ignore[attr-defined]
+ AtomicType, DataType, StructType,
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
 _parse_datatype_string
+)
from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
+if TYPE_CHECKING:
+ from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+ from pyspark.sql.catalog import Catalog
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.udf import UDFRegistration
+
+
__all__ = ["SparkSession"]
-def _monkey_patch_RDD(sparkSession):
- def toDF(self, schema=None, sampleRatio=None):
+def _monkey_patch_RDD(sparkSession: "SparkSession") -> None:
+
+ def toDF(
+ self: "RDD[RowLike]",
+ schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
Review comment:
> On a side note, we're missing `schema: str` variants, if I am not mistaken.

@zero323 May I ask you to fix the missing variant?
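
For reference, a rough sketch of what such a variant might look like (purely illustrative, not necessarily the final shape):

```python
# Hypothetical: also accept a DDL-formatted schema string such as
# "name STRING, age INT" in addition to a list/tuple of column names.
def toDF(
    self: "RDD[RowLike]",
    schema: Optional[Union[List[str], Tuple[str, ...], str]] = None,
    sampleRatio: Optional[float] = None,
) -> DataFrame:
    ...
```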
##########
File path: python/pyspark/sql/session.py
##########
@@ -19,24 +19,46 @@
import warnings
from functools import reduce
from threading import RLock
+from types import TracebackType
+from typing import (
+ Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+ cast, no_type_check, overload, TYPE_CHECKING
+)
-from pyspark import since
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import SparkConf, SparkContext, since
from pyspark.rdd import RDD
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import DataType, StructType, \
- _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+from pyspark.sql.types import ( # type: ignore[attr-defined]
+ AtomicType, DataType, StructType,
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
 _parse_datatype_string
+)
from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
+if TYPE_CHECKING:
+ from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+ from pyspark.sql.catalog import Catalog
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.udf import UDFRegistration
+
+
__all__ = ["SparkSession"]
-def _monkey_patch_RDD(sparkSession):
- def toDF(self, schema=None, sampleRatio=None):
+def _monkey_patch_RDD(sparkSession: "SparkSession") -> None:
+
+ def toDF(
+ self: "RDD[RowLike]",
+ schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
Review comment:
> On a side note, we're missing `schema: str` variants, if I am not mistaken.

@zero323 May I ask you to fix the missing variants?
##########
File path: python/pyspark/sql/session.py
##########
@@ -19,24 +19,46 @@
import warnings
from functools import reduce
from threading import RLock
+from types import TracebackType
+from typing import (
+ Any, Dict, Iterable, List, Optional, Tuple, Type, Union,
+ cast, no_type_check, overload, TYPE_CHECKING
+)
-from pyspark import since
+from py4j.java_gateway import JavaObject # type: ignore[import]
+
+from pyspark import SparkConf, SparkContext, since
from pyspark.rdd import RDD
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.conversion import SparkConversionMixin
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import DataType, StructType, \
- _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
+from pyspark.sql.types import ( # type: ignore[attr-defined]
+ AtomicType, DataType, StructType,
+ _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter,
 _parse_datatype_string
+)
from pyspark.sql.utils import install_exception_handler, is_timestamp_ntz_preferred
+if TYPE_CHECKING:
+ from pyspark.sql._typing import DateTimeLiteral, LiteralType, DecimalLiteral, RowLike
+ from pyspark.sql.catalog import Catalog
+ from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+ from pyspark.sql.streaming import StreamingQueryManager
+ from pyspark.sql.udf import UDFRegistration
+
+
__all__ = ["SparkSession"]
-def _monkey_patch_RDD(sparkSession):
- def toDF(self, schema=None, sampleRatio=None):
+def _monkey_patch_RDD(sparkSession: "SparkSession") -> None:
+
+ def toDF(
+ self: "RDD[RowLike]",
+ schema: Optional[Union[List[str], Tuple[str, ...]]] = None,
Review comment:
Thanks!
##########
File path: python/pyspark/sql/session.py
##########
@@ -566,7 +629,70 @@ def _create_shell_session():
return SparkSession.builder.getOrCreate()
- def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ samplingRatio: Optional[float] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[List[str], Tuple[str, ...]] = ...,
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType",
"DecimalLiteral"]],
+ ],
+ schema: Union[AtomicType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: "PandasDataFrameLike",
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ def createDataFrame( # type: ignore[misc]
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral,
RowLike]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral",
"RowLike"]],
+ "PandasDataFrameLike",
+ ],
+ schema: Optional[Union[AtomicType, StructType, str]] = None,
+ samplingRatio: Optional[float] = None,
+ verifySchema: bool = True
+ ) -> DataFrame:
Review comment:
For overloaded functions, the implementation function (the one with the actual body) is not exposed to type checkers, so they will still raise such an error.
The type hints on the implementation are purely for mypy to check the function body.
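
As a small self-contained illustration of that behavior (a generic example, not PySpark-specific):

```python
from typing import Union, overload

@overload
def parse(value: int) -> int: ...
@overload
def parse(value: str) -> str: ...

# Callers are matched only against the @overload stubs above; this
# implementation signature is invisible to them, and mypy uses it
# solely to type-check the body.
def parse(value: Union[int, str]) -> Union[int, str]:
    return value

parse(1.5)  # error: No overload variant of "parse" matches argument type "float"
```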
##########
File path: python/pyspark/sql/session.py
##########
@@ -566,7 +629,70 @@ def _create_shell_session():
return SparkSession.builder.getOrCreate()
- def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ samplingRatio: Optional[float] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[List[str], Tuple[str, ...]] = ...,
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType",
"DecimalLiteral"]],
+ ],
+ schema: Union[AtomicType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: "PandasDataFrameLike",
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ def createDataFrame( # type: ignore[misc]
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral,
RowLike]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral",
"RowLike"]],
+ "PandasDataFrameLike",
+ ],
+ schema: Optional[Union[AtomicType, StructType, str]] = None,
+ samplingRatio: Optional[float] = None,
+ verifySchema: bool = True
+ ) -> DataFrame:
Review comment:
If we remove the annotations, `mypy` won't check the function body.
Making `mypy` check the function body is the purpose of this series of PRs, so that we can more easily catch the misuse of variables.
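
For example (assuming default mypy settings, i.e. without `--check-untyped-defs`):

```python
def length_untyped(value):
    # No annotations, so by default mypy does not type-check this body
    # and the bad len() call below goes unreported.
    return len(42)


def length_typed(value: str) -> int:
    # Annotated, so the body is checked:
    # error: Argument 1 to "len" has incompatible type "int"; expected "Sized"
    return len(42)
```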
##########
File path: python/pyspark/sql/session.py
##########
@@ -566,7 +629,70 @@ def _create_shell_session():
return SparkSession.builder.getOrCreate()
- def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ samplingRatio: Optional[float] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[List[str], Tuple[str, ...]] = ...,
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType",
"DecimalLiteral"]],
+ ],
+ schema: Union[AtomicType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: "PandasDataFrameLike",
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ def createDataFrame( # type: ignore[misc]
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral,
RowLike]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral",
"RowLike"]],
+ "PandasDataFrameLike",
+ ],
+ schema: Optional[Union[AtomicType, StructType, str]] = None,
+ samplingRatio: Optional[float] = None,
+ verifySchema: bool = True
+ ) -> DataFrame:
Review comment:
If we remove the annotations, `mypy` won't check the function body.
Making `mypy` check the function body is one of the purposes of this series of PRs, so that we can more easily catch the misuse of variables.
##########
File path: python/pyspark/sql/session.py
##########
@@ -566,7 +629,70 @@ def _create_shell_session():
return SparkSession.builder.getOrCreate()
- def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ samplingRatio: Optional[float] = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[List[str], Tuple[str, ...]] = ...,
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType",
"DecimalLiteral"]],
+ ],
+ schema: Union[AtomicType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: Union["RDD[RowLike]", Iterable["RowLike"]],
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
+ ) -> DataFrame:
+ ...
+
+ @overload
+ def createDataFrame(
+ self,
+ data: "PandasDataFrameLike",
+ schema: Union[StructType, str],
+ verifySchema: bool = ...,
+ ) -> DataFrame:
+ ...
+
+ def createDataFrame( # type: ignore[misc]
+ self,
+ data: Union[
+ "RDD[Union[DateTimeLiteral, LiteralType, DecimalLiteral,
RowLike]]",
+ Iterable[Union["DateTimeLiteral", "LiteralType", "DecimalLiteral",
"RowLike"]],
+ "PandasDataFrameLike",
+ ],
+ schema: Optional[Union[AtomicType, StructType, str]] = None,
+ samplingRatio: Optional[float] = None,
+ verifySchema: bool = True
+ ) -> DataFrame:
Review comment:
Thank YOU for asking!