This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6108364  [SPARK-37296][PYTHON] Add missing type hints in python/pyspark/util.py
6108364 is described below

commit 6108364a67ee3dac22440e4af1aa8da14f505917
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Nov 12 12:44:48 2021 +0900

    [SPARK-37296][PYTHON] Add missing type hints in python/pyspark/util.py
    
    ### What changes were proposed in this pull request?
    
    Adds missing type hints in `python/pyspark/util.py`.
    
    ### Why are the changes needed?
    
    Some type hints in `python/pyspark/util.py` are missing, but the resulting
    mypy errors were suppressed by setting `disallow_untyped_defs = False` for this
    module in `mypy.ini`. We should remove that setting and add the missing type hints.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    `lint-python` should pass.
    
    Closes #34563 from ueshin/issues/SPARK-37296/util.
    
    Authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/mypy.ini        |  3 ---
 python/pyspark/util.py | 70 ++++++++++++++++++++++++++++----------------------
 2 files changed, 39 insertions(+), 34 deletions(-)

diff --git a/python/mypy.ini b/python/mypy.ini
index eb29109..a89a750 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -95,9 +95,6 @@ disallow_untyped_defs = False
 [mypy-pyspark.traceback_utils]
 disallow_untyped_defs = False
 
-[mypy-pyspark.util]
-disallow_untyped_defs = False
-
 [mypy-pyspark.worker]
 disallow_untyped_defs = False
 
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 052dc65..7b43a19 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -24,19 +24,15 @@ import re
 import sys
 import threading
 import traceback
-import types
+from types import TracebackType
+from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple
 
-try:
-    from collections.abc import Callable
-except AttributeError:
-    from collections import Callable
+from py4j.clientserver import ClientServer  # type: ignore[import]
 
-from py4j.clientserver import ClientServer
+__all__: List[str] = []
 
-__all__ = []  # type: ignore
 
-
-def print_exec(stream):
+def print_exec(stream: TextIO) -> None:
     ei = sys.exc_info()
     traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
 
@@ -47,7 +43,7 @@ class VersionUtils(object):
     """
 
     @staticmethod
-    def majorMinorVersion(sparkVersion: str):
+    def majorMinorVersion(sparkVersion: str) -> Tuple[int, int]:
         """
         Given a Spark version string, return the (major version number, minor 
version number).
         E.g., for 2.0.1-SNAPSHOT, return (2, 0).
@@ -72,13 +68,13 @@ class VersionUtils(object):
             )
 
 
-def fail_on_stopiteration(f):
+def fail_on_stopiteration(f: Callable) -> Callable:
     """
     Wraps the input function to fail on 'StopIteration' by raising a 
'RuntimeError'
     prevents silent loss of data when 'f' is used in a for loop in Spark code
     """
 
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: Any, **kwargs: Any) -> Any:
         try:
             return f(*args, **kwargs)
         except StopIteration as exc:
@@ -89,13 +85,13 @@ def fail_on_stopiteration(f):
     return wrapper
 
 
-def walk_tb(tb):
+def walk_tb(tb: Optional[TracebackType]) -> Iterator[TracebackType]:
     while tb is not None:
         yield tb
         tb = tb.tb_next
 
 
-def try_simplify_traceback(tb):
+def try_simplify_traceback(tb: TracebackType) -> Optional[TracebackType]:
     """
     Simplify the traceback. It removes the tracebacks in the current package, 
and only
     shows the traceback that is related to the thirdparty and user-specified 
codes.
@@ -218,7 +214,7 @@ def try_simplify_traceback(tb):
 
     for cur_tb, cur_frame in reversed(list(itertools.chain(last_seen, pairs))):
         # Once we have seen the file names outside, don't skip.
-        new_tb = types.TracebackType(
+        new_tb = TracebackType(
             tb_next=tb_next,
             tb_frame=cur_tb.tb_frame,
             tb_lasti=cur_tb.tb_frame.f_lasti,
@@ -228,7 +224,7 @@ def try_simplify_traceback(tb):
     return new_tb
 
 
-def _print_missing_jar(lib_name, pkg_name, jar_name, spark_version):
+def _print_missing_jar(lib_name: str, pkg_name: str, jar_name: str, 
spark_version: str) -> None:
     print(
         """
 
________________________________________________________________________________________________
@@ -258,7 +254,7 @@ 
________________________________________________________________________________
     )
 
 
-def _parse_memory(s):
+def _parse_memory(s: str) -> int:
     """
     Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
     return the value in MiB
@@ -335,10 +331,12 @@ def inheritable_thread_target(f: Callable) -> Callable:
         )
 
         @functools.wraps(f)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
             try:
                 # Set local properties in child thread.
-                
SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
+                
SparkContext._active_spark_context._jsc.sc().setLocalProperties(  # type: 
ignore[attr-defined]
+                    properties
+                )
                 return f(*args, **kwargs)
             finally:
                 InheritableThread._clean_py4j_conn_for_current_thread()
@@ -369,39 +367,49 @@ class InheritableThread(threading.Thread):
     This API is experimental.
     """
 
-    def __init__(self, target, *args, **kwargs):
+    def __init__(self, target: Callable, *args: Any, **kwargs: Any):
         from pyspark import SparkContext
 
-        if isinstance(SparkContext._gateway, ClientServer):
+        if isinstance(SparkContext._gateway, ClientServer):  # type: 
ignore[attr-defined]
             # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
-            def copy_local_properties(*a, **k):
+            def copy_local_properties(*a: Any, **k: Any) -> Any:
                 # self._props is set before starting the thread to match the 
behavior with JVM.
                 assert hasattr(self, "_props")
-                
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
+                
SparkContext._active_spark_context._jsc.sc().setLocalProperties(  # type: 
ignore[attr-defined]
+                    self._props
+                )
                 try:
                     return target(*a, **k)
                 finally:
                     InheritableThread._clean_py4j_conn_for_current_thread()
 
-            super(InheritableThread, 
self).__init__(target=copy_local_properties, *args, **kwargs)
+            super(InheritableThread, self).__init__(
+                target=copy_local_properties, *args, **kwargs  # type: 
ignore[misc]
+            )
         else:
-            super(InheritableThread, self).__init__(target=target, *args, 
**kwargs)
+            super(InheritableThread, self).__init__(
+                target=target, *args, **kwargs  # type: ignore[misc]
+            )
 
-    def start(self, *args, **kwargs):
+    def start(self) -> None:
         from pyspark import SparkContext
 
-        if isinstance(SparkContext._gateway, ClientServer):
+        if isinstance(SparkContext._gateway, ClientServer):  # type: 
ignore[attr-defined]
             # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
 
             # Local property copy should happen in Thread.start to mimic JVM's 
behavior.
-            self._props = 
SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
-        return super(InheritableThread, self).start(*args, **kwargs)
+            self._props = (
+                SparkContext._active_spark_context._jsc.sc()  # type: 
ignore[attr-defined]
+                .getLocalProperties()
+                .clone()
+            )
+        return super(InheritableThread, self).start()
 
     @staticmethod
-    def _clean_py4j_conn_for_current_thread():
+    def _clean_py4j_conn_for_current_thread() -> None:
         from pyspark import SparkContext
 
-        jvm = SparkContext._jvm
+        jvm = SparkContext._jvm  # type: ignore[attr-defined]
         thread_connection = jvm._gateway_client.get_thread_connection()
         if thread_connection is not None:
             try:

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to