This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6108364 [SPARK-37296][PYTHON] Add missing type hints in python/pyspark/util.py
6108364 is described below
commit 6108364a67ee3dac22440e4af1aa8da14f505917
Author: Takuya UESHIN <[email protected]>
AuthorDate: Fri Nov 12 12:44:48 2021 +0900
[SPARK-37296][PYTHON] Add missing type hints in python/pyspark/util.py
### What changes were proposed in this pull request?
Adds missing type hints in `python/pyspark/util.py`.
### Why are the changes needed?
Some of the type hints in `python/pyspark/util.py` are missing, but the errors
were ignored by setting `disallow_untyped_defs = False` in `mypy.ini`. We should
remove that setting and add the missing type hints.
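For reference, this is the per-module override being removed from
`python/mypy.ini`; while present, it exempted `pyspark.util` from the stricter
project-wide `disallow_untyped_defs` check:

    [mypy-pyspark.util]
    disallow_untyped_defs = False

With the override gone, any unannotated function in `pyspark.util` becomes a
type-checking error under `lint-python`.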
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
`lint-python` should pass.
Closes #34563 from ueshin/issues/SPARK-37296/util.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/mypy.ini | 3 ---
python/pyspark/util.py | 70 ++++++++++++++++++++++++++++----------------------
2 files changed, 39 insertions(+), 34 deletions(-)
diff --git a/python/mypy.ini b/python/mypy.ini
index eb29109..a89a750 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -95,9 +95,6 @@ disallow_untyped_defs = False
[mypy-pyspark.traceback_utils]
disallow_untyped_defs = False
-[mypy-pyspark.util]
-disallow_untyped_defs = False
-
[mypy-pyspark.worker]
disallow_untyped_defs = False
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 052dc65..7b43a19 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -24,19 +24,15 @@ import re
import sys
import threading
import traceback
-import types
+from types import TracebackType
+from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple
-try:
- from collections.abc import Callable
-except AttributeError:
- from collections import Callable
+from py4j.clientserver import ClientServer # type: ignore[import]
-from py4j.clientserver import ClientServer
+__all__: List[str] = []
-__all__ = [] # type: ignore
-
-def print_exec(stream):
+def print_exec(stream: TextIO) -> None:
ei = sys.exc_info()
traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
@@ -47,7 +43,7 @@ class VersionUtils(object):
"""
@staticmethod
- def majorMinorVersion(sparkVersion: str):
+ def majorMinorVersion(sparkVersion: str) -> Tuple[int, int]:
"""
Given a Spark version string, return the (major version number, minor version number).
E.g., for 2.0.1-SNAPSHOT, return (2, 0).
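A usage sketch of the newly annotated method (illustrative only; the return
value follows the docstring above):

    from pyspark.util import VersionUtils

    major, minor = VersionUtils.majorMinorVersion("3.2.0")  # (3, 2)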
@@ -72,13 +68,13 @@ class VersionUtils(object):
)
-def fail_on_stopiteration(f):
+def fail_on_stopiteration(f: Callable) -> Callable:
"""
Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
prevents silent loss of data when 'f' is used in a for loop in Spark code
"""
- def wrapper(*args, **kwargs):
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
try:
return f(*args, **kwargs)
except StopIteration as exc:
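A minimal sketch of the wrapper's effect (the user function here is
hypothetical, not part of the patch):

    def key_func(x: int) -> int:
        raise StopIteration("user bug")

    safe = fail_on_stopiteration(key_func)
    safe(1)  # raises RuntimeError, rather than silently ending an enclosing loop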
@@ -89,13 +85,13 @@ def fail_on_stopiteration(f):
return wrapper
-def walk_tb(tb):
+def walk_tb(tb: Optional[TracebackType]) -> Iterator[TracebackType]:
while tb is not None:
yield tb
tb = tb.tb_next
-def try_simplify_traceback(tb):
+def try_simplify_traceback(tb: TracebackType) -> Optional[TracebackType]:
"""
Simplify the traceback. It removes the tracebacks in the current package, and only
shows the traceback that is related to the thirdparty and user-specified codes.
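As a quick illustration, `walk_tb` yields each traceback link until `tb_next`
is exhausted (sketch only):

    import sys

    try:
        1 / 0
    except ZeroDivisionError:
        tb = sys.exc_info()[2]
        files = [t.tb_frame.f_code.co_filename for t in walk_tb(tb)]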
@@ -218,7 +214,7 @@ def try_simplify_traceback(tb):
for cur_tb, cur_frame in reversed(list(itertools.chain(last_seen, pairs))):
# Once we have seen the file names outside, don't skip.
- new_tb = types.TracebackType(
+ new_tb = TracebackType(
tb_next=tb_next,
tb_frame=cur_tb.tb_frame,
tb_lasti=cur_tb.tb_frame.f_lasti,
@@ -228,7 +224,7 @@ def try_simplify_traceback(tb):
return new_tb
-def _print_missing_jar(lib_name, pkg_name, jar_name, spark_version):
+def _print_missing_jar(lib_name: str, pkg_name: str, jar_name: str, spark_version: str) -> None:
print(
"""
________________________________________________________________________________________________
@@ -258,7 +254,7 @@ def _print_missing_jar(lib_name, pkg_name, jar_name, spark_version):
________________________________________________________________________________
)
-def _parse_memory(s):
+def _parse_memory(s: str) -> int:
"""
Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
return the value in MiB
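Per the docstring, the annotated function maps Java-style memory strings to
MiB, e.g. (expected values shown as a sketch):

    _parse_memory("1g")    # -> 1024
    _parse_memory("200m")  # -> 200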
@@ -335,10 +331,12 @@ def inheritable_thread_target(f: Callable) -> Callable:
)
@functools.wraps(f)
- def wrapped(*args, **kwargs):
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
try:
# Set local properties in child thread.
- SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
+ SparkContext._active_spark_context._jsc.sc().setLocalProperties(  # type: ignore[attr-defined]
+     properties
+ )
return f(*args, **kwargs)
finally:
InheritableThread._clean_py4j_conn_for_current_thread()
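For context, `inheritable_thread_target` is used as a decorator so a function
executed in a new thread inherits the parent thread's Spark local properties
(usage sketch; assumes an active SparkContext in pinned-thread mode):

    import threading
    from pyspark import inheritable_thread_target

    @inheritable_thread_target
    def work() -> None:
        ...  # runs with the local properties captured from the parent thread

    threading.Thread(target=work).start()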
@@ -369,39 +367,49 @@ class InheritableThread(threading.Thread):
This API is experimental.
"""
- def __init__(self, target, *args, **kwargs):
+ def __init__(self, target: Callable, *args: Any, **kwargs: Any):
from pyspark import SparkContext
- if isinstance(SparkContext._gateway, ClientServer):
+ if isinstance(SparkContext._gateway, ClientServer):  # type: ignore[attr-defined]
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
- def copy_local_properties(*a, **k):
+ def copy_local_properties(*a: Any, **k: Any) -> Any:
# self._props is set before starting the thread to match the behavior with JVM.
assert hasattr(self, "_props")
- SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
+ SparkContext._active_spark_context._jsc.sc().setLocalProperties(  # type: ignore[attr-defined]
+     self._props
+ )
try:
return target(*a, **k)
finally:
InheritableThread._clean_py4j_conn_for_current_thread()
- super(InheritableThread, self).__init__(target=copy_local_properties, *args, **kwargs)
+ super(InheritableThread, self).__init__(
+     target=copy_local_properties, *args, **kwargs  # type: ignore[misc]
+ )
else:
- super(InheritableThread, self).__init__(target=target, *args, **kwargs)
+ super(InheritableThread, self).__init__(
+     target=target, *args, **kwargs  # type: ignore[misc]
+ )
- def start(self, *args, **kwargs):
+ def start(self) -> None:
from pyspark import SparkContext
- if isinstance(SparkContext._gateway, ClientServer):
+ if isinstance(SparkContext._gateway, ClientServer):  # type: ignore[attr-defined]
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
# Local property copy should happen in Thread.start to mimic JVM's behavior.
- self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
- return super(InheritableThread, self).start(*args, **kwargs)
+ self._props = (
+     SparkContext._active_spark_context._jsc.sc()  # type: ignore[attr-defined]
+     .getLocalProperties()
+     .clone()
+ )
+ return super(InheritableThread, self).start()
@staticmethod
- def _clean_py4j_conn_for_current_thread():
+ def _clean_py4j_conn_for_current_thread() -> None:
from pyspark import SparkContext
- jvm = SparkContext._jvm
+ jvm = SparkContext._jvm # type: ignore[attr-defined]
thread_connection = jvm._gateway_client.get_thread_connection()
if thread_connection is not None:
try:
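For completeness, a usage sketch of `InheritableThread` with its now-typed
constructor (assumes an active SparkContext; in pinned-thread mode the
property copy happens inside `start`):

    from pyspark import InheritableThread

    t = InheritableThread(target=print, args=("hello",))
    t.start()  # local properties are copied from the parent thread here
    t.join()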
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]