zero323 commented on a change in pull request #34466:
URL: https://github.com/apache/spark/pull/34466#discussion_r753785413
##########
File path: python/pyspark/accumulators.py
##########
@@ -46,7 +48,10 @@ def _deserialize_accumulator(aid, zero_value, accum_param):
return accum
-class Accumulator(object):
+T = TypeVar("T")
Review comment:
Let's not touch things which are not strictly related to this particular
PR. Not to mention we should never have annotations in both `py` and `pyi`
files.
##########
File path: python/pyspark/broadcast.py
##########
@@ -42,7 +43,10 @@ def _from_id(bid):
return _broadcastRegistry[bid]
-class Broadcast(object):
+T = TypeVar("T")
+
+
+class Broadcast(Generic[T], object):
Review comment:
Ditto
##########
File path: python/pyspark/context.py
##########
@@ -125,30 +144,36 @@ class SparkContext(object):
ValueError: ...
"""
- _gateway = None
- _jvm = None
+ _gateway: JavaGateway = None
+ _jvm: JVMView = None
_next_accum_id = 0
- _active_spark_context = None
+ _active_spark_context: Optional["SparkContext"] = None
_lock = RLock()
- _python_includes = None # zip and egg files that need to be added to PYTHONPATH
-
+ _python_includes: Optional[
+ List[str]
+ ] = None # zip and egg files that need to be added to PYTHONPATH
+ profiler_collector: Optional[ProfilerCollector]
+ serializer: Serializer
Review comment:
These should be marked as `ClassVar`.
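For example, a rough sketch (using `ClassVar` and `Optional` from `typing`, since these start out as `None`):
```python
_gateway: ClassVar[Optional[JavaGateway]] = None
_jvm: ClassVar[Optional[JVMView]] = None
_active_spark_context: ClassVar[Optional["SparkContext"]] = None
```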
##########
File path: python/pyspark/context.py
##########
@@ -180,20 +205,20 @@ def __init__(
def _do_init(
self,
- master,
- appName,
- sparkHome,
- pyFiles,
- environment,
- batchSize,
- serializer,
- conf,
- jsc,
- profiler_cls,
- ):
+ master: Optional[str],
+ appName: Optional[str],
+ sparkHome: Optional[str],
+ pyFiles: Optional[List[str]],
+ environment: Optional[Dict[str, str]],
+ batchSize: int,
+ serializer: Serializer,
+ conf: Optional[SparkConf],
+ jsc: JavaObject,
+ profiler_cls: type,
+ ) -> Any:
Review comment:
`_do_init` is used only for its side effects and returns nothing, so it should be
annotated `-> None`, not `-> Any`.
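i.e. keep the parameter annotations from this diff and change only the return type:
```python
def _do_init(
    self,
    master: Optional[str],
    appName: Optional[str],
    sparkHome: Optional[str],
    pyFiles: Optional[List[str]],
    environment: Optional[Dict[str, str]],
    batchSize: int,
    serializer: Serializer,
    conf: Optional[SparkConf],
    jsc: JavaObject,
    profiler_cls: type,
) -> None:
    ...
```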
##########
File path: python/pyspark/context.py
##########
@@ -221,7 +246,7 @@ def _do_init(
if environment:
for key, value in environment.items():
self._conf.setExecutorEnv(key, value)
- for key, value in DEFAULT_CONFIGS.items():
+ for key, value in DEFAULT_CONFIGS.items(): # type: ignore[assignment]
Review comment:
Instead of ignoring here, we can add
```python
DEFAULT_CONFIGS: Dict[str, Any] = {
```
and, for consistency, adjust `environment` in `__init__` and `_do_init` as
well.
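Something along these lines (the dict entries are just the current defaults, shown for illustration; the point is the `Dict[str, Any]` annotation):
```python
DEFAULT_CONFIGS: Dict[str, Any] = {
    "spark.serializer.objectStreamReset": 100,
    "spark.rdd.compress": True,
}
```
and then `environment: Optional[Dict[str, Any]]` in `__init__` / `_do_init`.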
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
Review comment:
We can add this to the corresponding stub file, so we can skip the ignore,
but I won't insist on that.
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
- auth_token = self._gateway.gateway_parameters.auth_token
- self._accumulatorServer = accumulators._start_update_server(auth_token)
+ auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+ self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
Review comment:
Ditto
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
- auth_token = self._gateway.gateway_parameters.auth_token
- self._accumulatorServer = accumulators._start_update_server(auth_token)
+ auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+ self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
+ auth_token
+ ) # type: ignore[attr-defined]
Review comment:
Nit: We should only need one of these ignores.
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
- auth_token = self._gateway.gateway_parameters.auth_token
- self._accumulatorServer = accumulators._start_update_server(auth_token)
+ auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
Review comment:
`_gateway` is `ClientServer`, not `JVMView`.
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
- auth_token = self._gateway.gateway_parameters.auth_token
- self._accumulatorServer = accumulators._start_update_server(auth_token)
+ auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+ self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
+ auth_token
+ ) # type: ignore[attr-defined]
(host, port) = self._accumulatorServer.server_address
- self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
+ self._javaAccumulator = cast(JVMView, self._jvm).PythonAccumulatorV2(host, port, auth_token)
Review comment:
Given how many of these casts we have in this scope, it might be better
to extract a local variable
```python
jvm = cast(JVMView, self._jvm)
```
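so the later call sites read, for example:
```python
jvm = cast(JVMView, self._jvm)
self._javaAccumulator = jvm.PythonAccumulatorV2(host, port, auth_token)
```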
##########
File path: python/pyspark/context.py
##########
@@ -324,21 +359,23 @@ def _do_init(
self.profiler_collector = None
# create a signal handler which would be invoked on receiving SIGINT
- def signal_handler(signal, frame):
+ def signal_handler(signal: Any, frame: Any) -> None:
self.cancelAllJobs()
raise KeyboardInterrupt()
Review comment:
Might be overkill, but you can take a look at
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L188
and
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L83
Also, the return type might be `NoReturn` in our case.
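A sketch of what I have in mind (typing `frame` after the typeshed handler signature, and keeping the closure over `self` inside `_do_init` as in the current code):
```python
from types import FrameType
from typing import NoReturn, Optional

def signal_handler(signal: int, frame: Optional[FrameType]) -> NoReturn:
    self.cancelAllJobs()
    raise KeyboardInterrupt()
```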
##########
File path: python/pyspark/context.py
##########
@@ -324,21 +359,23 @@ def _do_init(
self.profiler_collector = None
# create a signal handler which would be invoked on receiving SIGINT
- def signal_handler(signal, frame):
+ def signal_handler(signal: Any, frame: Any) -> None:
self.cancelAllJobs()
raise KeyboardInterrupt()
# see http://stackoverflow.com/questions/23206787/
- if isinstance(threading.current_thread(), threading._MainThread):
+ if isinstance(
+ threading.current_thread(), threading._MainThread  # type: ignore[attr-defined]
+ ): # type: ignore[attr-defined]
Review comment:
Only one of these should be necessary.
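i.e. (body unchanged):
```python
if isinstance(
    threading.current_thread(), threading._MainThread  # type: ignore[attr-defined]
):
    ...
```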
##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
# Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
# Reset the SparkConf to the one actually used by the SparkContext in JVM.
self._conf = SparkConf(_jconf=self._jsc.sc().conf())
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
- auth_token = self._gateway.gateway_parameters.auth_token
- self._accumulatorServer = accumulators._start_update_server(auth_token)
+ auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
Review comment:
`_gateway` is `ClientServer` or `JavaGateway`, not `JVMView`.
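For example (assuming `JavaGateway` is imported from `py4j.java_gateway`; it already exposes `gateway_parameters`):
```python
auth_token = cast(JavaGateway, self._gateway).gateway_parameters.auth_token
```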
##########
File path: python/pyspark/context.py
##########
@@ -434,17 +481,17 @@ def getOrCreate(cls, conf=None):
with SparkContext._lock:
if SparkContext._active_spark_context is None:
SparkContext(conf=conf or SparkConf())
- return SparkContext._active_spark_context
+ return cast("SparkContext", SparkContext._active_spark_context)
Review comment:
I'd `assert SparkContext._active_spark_context is not None` here instead of the cast.
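i.e. something like:
```python
with SparkContext._lock:
    if SparkContext._active_spark_context is None:
        SparkContext(conf=conf or SparkConf())
    assert SparkContext._active_spark_context is not None
    return SparkContext._active_spark_context
```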
##########
File path: python/pyspark/context.py
##########
@@ -565,9 +614,9 @@ def range(self, start, end=None, step=1, numSlices=None):
end = start
start = 0
- return self.parallelize(range(start, end, step), numSlices)
+ return self.parallelize(list(range(start, end, step)), numSlices)
- def parallelize(self, c, numSlices=None):
+ def parallelize(self, c: List[T], numSlices: Optional[int] = None) -> RDD[T]:
Review comment:
Why did you change the signature here?
https://github.com/apache/spark/blob/ef4f2546c58ef5fe67be7047f9aa2a793519fd54/python/pyspark/context.pyi#L100
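If I remember the stub correctly, it accepts any iterable rather than just a list, roughly:
```python
def parallelize(self, c: Iterable[T], numSlices: Optional[int] = None) -> RDD[T]:
    ...
```
so narrowing it to `List[T]` here would be a regression.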
##########
File path: python/pyspark/context.py
##########
@@ -587,10 +636,10 @@ def parallelize(self, c, numSlices=None):
step = c[1] - c[0] if size > 1 else 1
start0 = c[0]
- def getStart(split):
+ def getStart(split: T) -> T:
Review comment:
At first glance, this should be `(int) -> int`.
##########
File path: python/pyspark/context.py
##########
@@ -609,16 +658,20 @@ def f(split, iterator):
batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
- def reader_func(temp_filename):
- return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
+ def reader_func(temp_filename: str) -> JavaObject:
+ return cast(JVMView, self._jvm).PythonRDD.readRDDFromFile(
+ self._jsc, temp_filename, numSlices
+ )
- def createRDDServer():
- return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)
+ def createRDDServer() -> JavaObject:
+ return cast(JVMView, self._jvm).PythonParallelizeServer(self._jsc.sc(), numSlices)
jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
return RDD(jrdd, self, serializer)
- def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
+ def _serialize_to_jvm(
+ self, data: Any, serializer: Serializer, reader_func: Callable, createRDDServer: Callable
Review comment:
We might try to make it more specific. For example, we know that
`createRDDServer` is nullary and should be consistent with
https://github.com/apache/spark/blob/ef4f2546c58ef5fe67be7047f9aa2a793519fd54/python/pyspark/context.py#L615-L616
and `reader_func` should be consistent with the `reader_func` defined in `parallelize` above.
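For example, a sketch based on the two local functions above (keeping `data: Any` as in this diff; the return type is a guess):
```python
def _serialize_to_jvm(
    self,
    data: Any,
    serializer: Serializer,
    reader_func: Callable[[str], JavaObject],
    createRDDServer: Callable[[], JavaObject],
) -> JavaObject:
    ...
```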
##########
File path: python/pyspark/context.py
##########
@@ -787,8 +844,8 @@ def binaryRecords(self, path, recordLength):
"""
return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer())
- def _dictToJavaMap(self, d):
- jm = self._jvm.java.util.HashMap()
+ def _dictToJavaMap(self, d: Optional[Dict[str, str]]) -> Any:
Review comment:
Return type should be `py4j.java_collections.JavaMap`
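i.e. (assuming `from py4j.java_collections import JavaMap`):
```python
def _dictToJavaMap(self, d: Optional[Dict[str, str]]) -> JavaMap:
    ...
```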
##########
File path: python/pyspark/ml/clustering.py
##########
@@ -17,6 +17,7 @@
import sys
import warnings
+from typing import cast
Review comment:
`pyspark.ml` is still covered by stub files, so no `cast`s should be
necessary at the moment. Let's revert this and focus on the necessary changes.
##########
File path: python/pyspark/ml/feature.py
##########
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+from typing import cast
Review comment:
Ditto
##########
File path: python/pyspark/ml/wrapper.py
##########
@@ -16,6 +16,7 @@
#
from abc import ABCMeta, abstractmethod
+from typing import cast
Review comment:
Ditto
##########
File path: python/pyspark/pandas/spark/functions.py
##########
@@ -40,7 +40,7 @@ def repeat(col: Column, n: Union[int, Column]) -> Column:
"""
Repeats a string column n times, and returns it as a new string column.
"""
- sc = SparkContext._active_spark_context # type: ignore[attr-defined]
+ sc = cast(SparkContext, SparkContext._active_spark_context)  # type: ignore[attr-defined]
Review comment:
We should be able to drop `type: ignore` here.
##########
File path: python/pyspark/rdd.py
##########
@@ -250,7 +251,10 @@ def __call__(self, k):
return self.partitionFunc(k) % self.numPartitions
-class RDD(object):
+T = TypeVar("T")
+
+
+class RDD(Generic[T], object):
Review comment:
That shouldn't be necessary for this task.
##########
File path: python/pyspark/sql/avro/functions.py
##########
@@ -73,7 +73,7 @@ def from_avro(
[Row(value=Row(avro=Row(age=2, name='Alice')))]
"""
- sc = SparkContext._active_spark_context # type: ignore[attr-defined]
+ sc = cast(SparkContext, SparkContext._active_spark_context)  # type: ignore[attr-defined]
Review comment:
To avoid repeating this every time: we should be able to drop `type:
ignore[attr-defined]` from `SparkContext._active_spark_context`.
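which would leave just:
```python
sc = cast(SparkContext, SparkContext._active_spark_context)
```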
##########
File path: python/pyspark/sql/context.py
##########
@@ -101,6 +101,8 @@ class SQLContext(object):
"""
_instantiatedContext: ClassVar[Optional["SQLContext"]] = None
+ _jsqlContext: JavaObject
+ _jvm: JavaObject
Review comment:
Could you clarify why we are adding these?
##########
File path: python/pyspark/sql/pandas/conversion.py
##########
@@ -583,7 +583,7 @@ def create_RDD_server():
jrdd = self._sc._serialize_to_jvm( # type: ignore[attr-defined]
arrow_data, ser, reader_func, create_RDD_server
)
- jdf = self._jvm.PythonSQLUtils.toDataFrame(  # type: ignore[attr-defined]
+ jdf = self._jvm.PythonSQLUtils.toDataFrame( # type: ignore[has-type]
Review comment:
Could you elaborate on this change, please?
##########
File path: python/pyspark/sql/session.py
##########
@@ -137,6 +137,9 @@ class SparkSession(SparkConversionMixin):
[(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
+ _jsparkSession: JavaObject
+ _wrapped: "SQLContext"
Review comment:
Why do we add these?