This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 208465c3ca3 [SPARK-41591][PYTHON][ML][FOLLOW-UP] Fix type hints that
are incompatible with Python <= 3.8
208465c3ca3 is described below
commit 208465c3ca354d5c6d6f6e11d5c00e80159acdf6
Author: harupy <[email protected]>
AuthorDate: Fri Jan 13 12:37:21 2023 +0900
[SPARK-41591][PYTHON][ML][FOLLOW-UP] Fix type hints that are incompatible
with Python <= 3.8
### What changes were proposed in this pull request?
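This PR fixes type hints (added in
https://github.com/apache/spark/pull/39188) that are incompatible with Python
<= 3.8, by replacing the built-in generics `list[...]` and `dict[...]` with
`typing.List[...]` and `typing.Dict[...]`.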
### Why are the changes needed?
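Because PySpark still supports Python <= 3.8, where subscripting built-in
collection types in annotations (e.g. `list[str]`, `dict[str, Any]`) raises
`TypeError` at function definition time; that syntax requires Python 3.9+ (PEP 585).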
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests
Closes #39542 from harupy/fix-type-hints.
Authored-by: harupy <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 14 +++++++-------
python/pyspark/ml/torch/tests/test_distributor.py | 10 +++++-----
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/ml/torch/distributor.py
b/python/pyspark/ml/torch/distributor.py
index 80d5ad31c3c..3a59692cd12 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -25,7 +25,7 @@ import signal
import sys
import subprocess
import time
-from typing import Union, Callable, Optional, Any
+from typing import Union, Callable, List, Dict, Optional, Any
import warnings
from pyspark.sql import SparkSession
@@ -73,7 +73,7 @@ def get_conf_boolean(sc: SparkContext, key: str,
default_value: str) -> bool:
)
-def get_gpus_owned(sc: SparkContext) -> list[str]:
+def get_gpus_owned(sc: SparkContext) -> List[str]:
"""Gets the number of GPUs that Spark scheduled to the calling task.
Parameters
@@ -130,7 +130,7 @@ class Distributor:
self.num_tasks = self._get_num_tasks()
self.ssl_conf = None
- def _create_input_params(self) -> dict[str, Any]:
+ def _create_input_params(self) -> Dict[str, Any]:
input_params = self.__dict__.copy()
for unneeded_param in ["spark", "sc", "ssl_conf"]:
del input_params[unneeded_param]
@@ -316,8 +316,8 @@ class TorchDistributor(Distributor):
@staticmethod
def _create_torchrun_command(
- input_params: dict[str, Any], path_to_train_file: str, *args: Any
- ) -> list[str]:
+ input_params: Dict[str, Any], path_to_train_file: str, *args: Any
+ ) -> List[str]:
local_mode = input_params["local_mode"]
num_processes = input_params["num_processes"]
@@ -339,7 +339,7 @@ class TorchDistributor(Distributor):
@staticmethod
def _execute_command(
- cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
+ cmd: List[str], _prctl: bool = True, redirect_to_stdout: bool = True
) -> None:
_TAIL_LINES_TO_KEEP = 100
@@ -430,7 +430,7 @@ class TorchDistributor(Distributor):
@staticmethod
def _run_training_on_pytorch_file(
- input_params: dict[str, Any], train_path: str, *args: Any
+ input_params: Dict[str, Any], train_path: str, *args: Any
) -> None:
training_command = TorchDistributor._create_torchrun_command(
input_params, train_path, *args
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py
b/python/pyspark/ml/torch/tests/test_distributor.py
index 4b24eff8742..9f57024cc4e 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -24,7 +24,7 @@ import sys
import time
import tempfile
import threading
-from typing import Callable
+from typing import Callable, Dict
import unittest
from unittest.mock import patch
@@ -56,11 +56,11 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
def tearDown(self) -> None:
self.spark.stop()
- def setup_env_vars(self, input_map: dict[str, str]) -> None:
+ def setup_env_vars(self, input_map: Dict[str, str]) -> None:
for key, value in input_map.items():
os.environ[key] = value
- def delete_env_vars(self, input_map: dict[str, str]) -> None:
+ def delete_env_vars(self, input_map: Dict[str, str]) -> None:
for key in input_map.keys():
del os.environ[key]
@@ -196,11 +196,11 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
os.unlink(self.tempFile.name)
self.spark.stop()
- def setup_env_vars(self, input_map: dict[str, str]) -> None:
+ def setup_env_vars(self, input_map: Dict[str, str]) -> None:
for key, value in input_map.items():
os.environ[key] = value
- def delete_env_vars(self, input_map: dict[str, str]) -> None:
+ def delete_env_vars(self, input_map: Dict[str, str]) -> None:
for key in input_map.keys():
del os.environ[key]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]