zero323 commented on a change in pull request #34466:
URL: https://github.com/apache/spark/pull/34466#discussion_r753785413



##########
File path: python/pyspark/accumulators.py
##########
@@ -46,7 +48,10 @@ def _deserialize_accumulator(aid, zero_value, accum_param):
         return accum
 
 
-class Accumulator(object):
+T = TypeVar("T")

Review comment:
       Let's not touch things which are not strictly related to this particular PR. Not to mention we should never have annotations in both `py` and `pyi` files.

##########
File path: python/pyspark/broadcast.py
##########
@@ -42,7 +43,10 @@ def _from_id(bid):
     return _broadcastRegistry[bid]
 
 
-class Broadcast(object):
+T = TypeVar("T")
+
+
+class Broadcast(Generic[T], object):

Review comment:
       Ditto

##########
File path: python/pyspark/context.py
##########
@@ -125,30 +144,36 @@ class SparkContext(object):
     ValueError: ...
     """
 
-    _gateway = None
-    _jvm = None
+    _gateway: JavaGateway = None
+    _jvm: JVMView = None
     _next_accum_id = 0
-    _active_spark_context = None
+    _active_spark_context: Optional["SparkContext"] = None
     _lock = RLock()
-    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH
-
+    _python_includes: Optional[
+        List[str]
+    ] = None  # zip and egg files that need to be added to PYTHONPATH
+    profiler_collector: Optional[ProfilerCollector]
+    serializer: Serializer

Review comment:
       These should be marked as `ClassVar`.
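
   A minimal sketch of what I mean (attribute names are taken from the diff above; the exact value types are my assumption, please adjust as needed):

   ```python
   from typing import ClassVar, List, Optional

   from py4j.java_gateway import JavaGateway, JVMView


   class SparkContext:
       # class-level attributes shared by every instance, hence ClassVar
       _gateway: ClassVar[Optional[JavaGateway]] = None
       _jvm: ClassVar[Optional[JVMView]] = None
       _next_accum_id: ClassVar[int] = 0
       _active_spark_context: ClassVar[Optional["SparkContext"]] = None
       # zip and egg files that need to be added to PYTHONPATH
       _python_includes: ClassVar[Optional[List[str]]] = None
   ```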

##########
File path: python/pyspark/context.py
##########
@@ -180,20 +205,20 @@ def __init__(
 
     def _do_init(
         self,
-        master,
-        appName,
-        sparkHome,
-        pyFiles,
-        environment,
-        batchSize,
-        serializer,
-        conf,
-        jsc,
-        profiler_cls,
-    ):
+        master: Optional[str],
+        appName: Optional[str],
+        sparkHome: Optional[str],
+        pyFiles: Optional[List[str]],
+        environment: Optional[Dict[str, str]],
+        batchSize: int,
+        serializer: Serializer,
+        conf: Optional[SparkConf],
+        jsc: JavaObject,
+        profiler_cls: type,
+    ) -> Any:

Review comment:
       `_do_init` is used only for side effects and returns nothing, so it should be `-> None`.

##########
File path: python/pyspark/context.py
##########
@@ -221,7 +246,7 @@ def _do_init(
         if environment:
             for key, value in environment.items():
                 self._conf.setExecutorEnv(key, value)
-        for key, value in DEFAULT_CONFIGS.items():
+        for key, value in DEFAULT_CONFIGS.items():  # type: ignore[assignment]

Review comment:
       Instead of ignoring here, we can add
   
   ```python
   DEFAULT_CONFIGS: Dict[str, Any] = {
   ```
   
   and, for consistency, adjust `environment` in `__init__` and `_do_init` as well.
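
   For completeness, a sketch of the full statement (the literal entries are just what I believe is currently in `context.py`, so please double-check):

   ```python
   from typing import Any, Dict

   DEFAULT_CONFIGS: Dict[str, Any] = {
       "spark.serializer.objectStreamReset": 100,
       "spark.rdd.compress": True,
   }
   ```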

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]

Review comment:
       We can add this to the corresponding stub file, so we can skip the ignore here, but I won't insist on that.

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
         # Reset the SparkConf to the one actually used by the SparkContext in JVM.
         self._conf = SparkConf(_jconf=self._jsc.sc().conf())
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        auth_token = self._gateway.gateway_parameters.auth_token
-        self._accumulatorServer = accumulators._start_update_server(auth_token)
+        auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]

Review comment:
       Ditto

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
         # Reset the SparkConf to the one actually used by the SparkContext in JVM.
         self._conf = SparkConf(_jconf=self._jsc.sc().conf())
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        auth_token = self._gateway.gateway_parameters.auth_token
-        self._accumulatorServer = accumulators._start_update_server(auth_token)
+        auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
+            auth_token
+        )  # type: ignore[attr-defined]

Review comment:
       Nit: we should only need one of these ignores.
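
   That is, keeping only the comment on the opening line should be enough, something like:

   ```python
   self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
       auth_token
   )
   ```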

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
         # Reset the SparkConf to the one actually used by the SparkContext in JVM.
         self._conf = SparkConf(_jconf=self._jsc.sc().conf())
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        auth_token = self._gateway.gateway_parameters.auth_token
-        self._accumulatorServer = accumulators._start_update_server(auth_token)
+        auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token

Review comment:
       `_gateway` is `ClientServer`, not `JVMView`.
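
   For example (just a sketch; `JavaGateway` would also work, since `ClientServer` subclasses it):

   ```python
   from py4j.clientserver import ClientServer

   auth_token = cast(ClientServer, self._gateway).gateway_parameters.auth_token
   ```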

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
         # Reset the SparkConf to the one actually used by the SparkContext in JVM.
         self._conf = SparkConf(_jconf=self._jsc.sc().conf())
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        auth_token = self._gateway.gateway_parameters.auth_token
-        self._accumulatorServer = accumulators._start_update_server(auth_token)
+        auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(  # type: ignore[attr-defined]
+            auth_token
+        )  # type: ignore[attr-defined]
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
+        self._javaAccumulator = cast(JVMView, self._jvm).PythonAccumulatorV2(host, port, auth_token)

Review comment:
       Given how many of these casts we have in this scope, it might be better to extract a local variable
   
   ```python
   jvm = cast(JVMView, self._jvm)
   ```
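
   which the rest of the block can then reuse, e.g. (a sketch based on the lines above):

   ```python
   jvm = cast(JVMView, self._jvm)
   self._javaAccumulator = jvm.PythonAccumulatorV2(host, port, auth_token)
   ```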
   
   

##########
File path: python/pyspark/context.py
##########
@@ -324,21 +359,23 @@ def _do_init(
             self.profiler_collector = None
 
         # create a signal handler which would be invoked on receiving SIGINT
-        def signal_handler(signal, frame):
+        def signal_handler(signal: Any, frame: Any) -> None:
             self.cancelAllJobs()
             raise KeyboardInterrupt()

Review comment:
       Might be overkill, but you can take a look at
   
   
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L188
   
   and
   
   
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L83
   
   Also it might be `NoReturn` in our case.
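
   That is, something along these lines (a sketch based on those typeshed signatures; `NoReturn` because the handler always raises):

   ```python
   from types import FrameType
   from typing import NoReturn, Optional

   # nested inside _do_init, so it closes over self
   def signal_handler(signal: int, frame: Optional[FrameType]) -> NoReturn:
       self.cancelAllJobs()
       raise KeyboardInterrupt()
   ```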

##########
File path: python/pyspark/context.py
##########
@@ -324,21 +359,23 @@ def _do_init(
             self.profiler_collector = None
 
         # create a signal handler which would be invoked on receiving SIGINT
-        def signal_handler(signal, frame):
+        def signal_handler(signal: Any, frame: Any) -> None:

Review comment:
       Might be overkill, but you can take a look at
   
   
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L188
   
   and
   
   
https://github.com/python/typeshed/blob/ee487304d76c671aa353a15e34bc1102bffa2362/stdlib/signal.pyi#L83
   
   Also it might be `NoReturn` in our case.

##########
File path: python/pyspark/context.py
##########
@@ -324,21 +359,23 @@ def _do_init(
             self.profiler_collector = None
 
         # create a signal handler which would be invoked on receiving SIGINT
-        def signal_handler(signal, frame):
+        def signal_handler(signal: Any, frame: Any) -> None:
             self.cancelAllJobs()
             raise KeyboardInterrupt()
 
         # see http://stackoverflow.com/questions/23206787/
-        if isinstance(threading.current_thread(), threading._MainThread):
+        if isinstance(
+            threading.current_thread(), threading._MainThread  # type: ignore[attr-defined]
+        ):  # type: ignore[attr-defined]

Review comment:
       Only one of these should be necessary.

##########
File path: python/pyspark/context.py
##########
@@ -244,26 +269,32 @@ def _do_init(
         self.environment["PYTHONHASHSEED"] = os.environ.get("PYTHONHASHSEED", "0")
 
         # Create the Java SparkContext through Py4J
-        self._jsc = jsc or self._initialize_context(self._conf._jconf)
+        self._jsc = jsc or self._initialize_context(self._conf._jconf)  # type: ignore[attr-defined]
         # Reset the SparkConf to the one actually used by the SparkContext in JVM.
         self._conf = SparkConf(_jconf=self._jsc.sc().conf())
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        auth_token = self._gateway.gateway_parameters.auth_token
-        self._accumulatorServer = accumulators._start_update_server(auth_token)
+        auth_token = cast(JVMView, self._gateway).gateway_parameters.auth_token

Review comment:
       `_gateway` is `ClientServer` or `JavaGateway`, not `JVMView`.

##########
File path: python/pyspark/context.py
##########
@@ -434,17 +481,17 @@ def getOrCreate(cls, conf=None):
         with SparkContext._lock:
             if SparkContext._active_spark_context is None:
                 SparkContext(conf=conf or SparkConf())
-            return SparkContext._active_spark_context
+            return cast("SparkContext", SparkContext._active_spark_context)

Review comment:
       I'd `assert SparkContext._active_spark_context is not None` here instead of casting.
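
   That is, a sketch of how the body could look:

   ```python
   with SparkContext._lock:
       if SparkContext._active_spark_context is None:
           SparkContext(conf=conf or SparkConf())
       assert SparkContext._active_spark_context is not None
       return SparkContext._active_spark_context
   ```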

##########
File path: python/pyspark/context.py
##########
@@ -565,9 +614,9 @@ def range(self, start, end=None, step=1, numSlices=None):
             end = start
             start = 0
 
-        return self.parallelize(range(start, end, step), numSlices)
+        return self.parallelize(list(range(start, end, step)), numSlices)
 
-    def parallelize(self, c, numSlices=None):
+    def parallelize(self, c: List[T], numSlices: Optional[int] = None) -> RDD[T]:

Review comment:
       Why did you change the signature here?
   
   
https://github.com/apache/spark/blob/ef4f2546c58ef5fe67be7047f9aa2a793519fd54/python/pyspark/context.pyi#L100

##########
File path: python/pyspark/context.py
##########
@@ -565,9 +614,9 @@ def range(self, start, end=None, step=1, numSlices=None):
             end = start
             start = 0
 
-        return self.parallelize(range(start, end, step), numSlices)
+        return self.parallelize(list(range(start, end, step)), numSlices)
 
-    def parallelize(self, c, numSlices=None):
+    def parallelize(self, c: List[T], numSlices: Optional[int] = None) -> RDD[T]:

Review comment:
       Why did you change the signature here? 
   
   
https://github.com/apache/spark/blob/ef4f2546c58ef5fe67be7047f9aa2a793519fd54/python/pyspark/context.pyi#L100
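
   If I read the linked stub correctly, the previous annotation was broader, something like:

   ```python
   def parallelize(self, c: Iterable[T], numSlices: Optional[int] = None) -> RDD[T]: ...
   ```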

##########
File path: python/pyspark/context.py
##########
@@ -587,10 +636,10 @@ def parallelize(self, c, numSlices=None):
             step = c[1] - c[0] if size > 1 else 1
             start0 = c[0]
 
-            def getStart(split):
+            def getStart(split: T) -> T:

Review comment:
       At first glance, this should be `(int) -> int`.

##########
File path: python/pyspark/context.py
##########
@@ -609,16 +658,20 @@ def f(split, iterator):
         batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
         serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
 
-        def reader_func(temp_filename):
-            return self._jvm.PythonRDD.readRDDFromFile(self._jsc, temp_filename, numSlices)
+        def reader_func(temp_filename: str) -> JavaObject:
+            return cast(JVMView, self._jvm).PythonRDD.readRDDFromFile(
+                self._jsc, temp_filename, numSlices
+            )
 
-        def createRDDServer():
-            return self._jvm.PythonParallelizeServer(self._jsc.sc(), numSlices)
+        def createRDDServer() -> JavaObject:
+            return cast(JVMView, self._jvm).PythonParallelizeServer(self._jsc.sc(), numSlices)
 
         jrdd = self._serialize_to_jvm(c, serializer, reader_func, createRDDServer)
         return RDD(jrdd, self, serializer)
 
-    def _serialize_to_jvm(self, data, serializer, reader_func, createRDDServer):
+    def _serialize_to_jvm(
+        self, data: Any, serializer: Serializer, reader_func: Callable, createRDDServer: Callable

Review comment:
       We might try to make these more specific. For example, we know that `createRDDServer` is nullary and should be consistent with
   
   https://github.com/apache/spark/blob/ef4f2546c58ef5fe67be7047f9aa2a793519fd54/python/pyspark/context.py#L615-L616
   
   and `reader_func` should likewise be consistent with the local `reader_func` defined in `parallelize`.
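
   For example, a sketch with more specific parameter types (mirroring the local functions defined in `parallelize`):

   ```python
   def _serialize_to_jvm(
       self,
       data: Any,
       serializer: Serializer,
       reader_func: Callable[[str], JavaObject],
       createRDDServer: Callable[[], JavaObject],
   ):
       ...
   ```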

##########
File path: python/pyspark/context.py
##########
@@ -787,8 +844,8 @@ def binaryRecords(self, path, recordLength):
         """
         return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer())
 
-    def _dictToJavaMap(self, d):
-        jm = self._jvm.java.util.HashMap()
+    def _dictToJavaMap(self, d: Optional[Dict[str, str]]) -> Any:

Review comment:
       Return type should be `py4j.java_collections.JavaMap`
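
   That is, something like:

   ```python
   from py4j.java_collections import JavaMap

   def _dictToJavaMap(self, d: Optional[Dict[str, str]]) -> JavaMap:
       ...
   ```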

##########
File path: python/pyspark/ml/clustering.py
##########
@@ -17,6 +17,7 @@
 
 import sys
 import warnings
+from typing import cast

Review comment:
       `pyspark.ml` is still covered by stub files, so no casts should be necessary at the moment. Let's revert this and focus on the necessary changes.

##########
File path: python/pyspark/ml/feature.py
##########
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+from typing import cast

Review comment:
       Ditto

##########
File path: python/pyspark/ml/wrapper.py
##########
@@ -16,6 +16,7 @@
 #
 
 from abc import ABCMeta, abstractmethod
+from typing import cast

Review comment:
       Ditto

##########
File path: python/pyspark/pandas/spark/functions.py
##########
@@ -40,7 +40,7 @@ def repeat(col: Column, n: Union[int, Column]) -> Column:
     """
     Repeats a string column n times, and returns it as a new string column.
     """
-    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
+    sc = cast(SparkContext, SparkContext._active_spark_context)  # type: ignore[attr-defined]

Review comment:
       We should be able to drop `type: ignore` here.

##########
File path: python/pyspark/rdd.py
##########
@@ -250,7 +251,10 @@ def __call__(self, k):
         return self.partitionFunc(k) % self.numPartitions
 
 
-class RDD(object):
+T = TypeVar("T")
+
+
+class RDD(Generic[T], object):

Review comment:
       That shouldn't be necessary for this task.

##########
File path: python/pyspark/sql/avro/functions.py
##########
@@ -73,7 +73,7 @@ def from_avro(
     [Row(value=Row(avro=Row(age=2, name='Alice')))]
     """
 
-    sc = SparkContext._active_spark_context  # type: ignore[attr-defined]
+    sc = cast(SparkContext, SparkContext._active_spark_context)  # type: ignore[attr-defined]

Review comment:
       So that we don't have to repeat this every time: we should be able to drop `type: ignore[attr-defined]` from `SparkContext._active_spark_context`.
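
   Once `_active_spark_context` is annotated on `SparkContext` itself (as in the `context.py` diff above), this should type-check as just:

   ```python
   sc = cast(SparkContext, SparkContext._active_spark_context)
   ```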

##########
File path: python/pyspark/sql/context.py
##########
@@ -101,6 +101,8 @@ class SQLContext(object):
     """
 
     _instantiatedContext: ClassVar[Optional["SQLContext"]] = None
+    _jsqlContext: JavaObject
+    _jvm: JavaObject

Review comment:
       Could you clarify why we are adding these?

##########
File path: python/pyspark/sql/pandas/conversion.py
##########
@@ -583,7 +583,7 @@ def create_RDD_server():
         jrdd = self._sc._serialize_to_jvm(  # type: ignore[attr-defined]
             arrow_data, ser, reader_func, create_RDD_server
         )
-        jdf = self._jvm.PythonSQLUtils.toDataFrame(  # type: ignore[attr-defined]
+        jdf = self._jvm.PythonSQLUtils.toDataFrame(  # type: ignore[has-type]

Review comment:
       Could you elaborate on this change, please?

##########
File path: python/pyspark/sql/session.py
##########
@@ -137,6 +137,9 @@ class SparkSession(SparkConversionMixin):
     [(1, 'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
     """
 
+    _jsparkSession: JavaObject
+    _wrapped: "SQLContext"

Review comment:
       Why do we add these?



