This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 208465c3ca3 [SPARK-41591][PYTHON][ML][FOLLOW-UP] Fix type hints that are incompatible with Python <= 3.8
208465c3ca3 is described below

commit 208465c3ca354d5c6d6f6e11d5c00e80159acdf6
Author: harupy <[email protected]>
AuthorDate: Fri Jan 13 12:37:21 2023 +0900

    [SPARK-41591][PYTHON][ML][FOLLOW-UP] Fix type hints that are incompatible with Python <= 3.8
    
    ### What changes were proposed in this pull request?
    
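    This PR replaces the built-in generic type hints (`list[str]`, `dict[str, Any]`)
    added in https://github.com/apache/spark/pull/39188 with their `typing.List` /
    `typing.Dict` equivalents, which also work on Python <= 3.8.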
    
    ### Why are the changes needed?
    
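    Subscripting built-in types such as `list[str]` or `dict[str, Any]` (PEP 585) is
    only supported from Python 3.9, and PySpark still supports Python <= 3.8.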
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing unit tests
    
    Closes #39542 from harupy/fix-type-hints.
    
    Authored-by: harupy <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/ml/torch/distributor.py            | 14 +++++++-------
 python/pyspark/ml/torch/tests/test_distributor.py | 10 +++++-----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 80d5ad31c3c..3a59692cd12 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -25,7 +25,7 @@ import signal
 import sys
 import subprocess
 import time
-from typing import Union, Callable, Optional, Any
+from typing import Union, Callable, List, Dict, Optional, Any
 import warnings
 
 from pyspark.sql import SparkSession
@@ -73,7 +73,7 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
     )
 
 
-def get_gpus_owned(sc: SparkContext) -> list[str]:
+def get_gpus_owned(sc: SparkContext) -> List[str]:
     """Gets the number of GPUs that Spark scheduled to the calling task.
 
     Parameters
@@ -130,7 +130,7 @@ class Distributor:
         self.num_tasks = self._get_num_tasks()
         self.ssl_conf = None
 
-    def _create_input_params(self) -> dict[str, Any]:
+    def _create_input_params(self) -> Dict[str, Any]:
         input_params = self.__dict__.copy()
         for unneeded_param in ["spark", "sc", "ssl_conf"]:
             del input_params[unneeded_param]
@@ -316,8 +316,8 @@ class TorchDistributor(Distributor):
 
     @staticmethod
     def _create_torchrun_command(
-        input_params: dict[str, Any], path_to_train_file: str, *args: Any
-    ) -> list[str]:
+        input_params: Dict[str, Any], path_to_train_file: str, *args: Any
+    ) -> List[str]:
         local_mode = input_params["local_mode"]
         num_processes = input_params["num_processes"]
 
@@ -339,7 +339,7 @@ class TorchDistributor(Distributor):
 
     @staticmethod
     def _execute_command(
-        cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
+        cmd: List[str], _prctl: bool = True, redirect_to_stdout: bool = True
     ) -> None:
         _TAIL_LINES_TO_KEEP = 100
 
@@ -430,7 +430,7 @@ class TorchDistributor(Distributor):
 
     @staticmethod
     def _run_training_on_pytorch_file(
-        input_params: dict[str, Any], train_path: str, *args: Any
+        input_params: Dict[str, Any], train_path: str, *args: Any
     ) -> None:
         training_command = TorchDistributor._create_torchrun_command(
             input_params, train_path, *args
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 4b24eff8742..9f57024cc4e 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -24,7 +24,7 @@ import sys
 import time
 import tempfile
 import threading
-from typing import Callable
+from typing import Callable, Dict
 import unittest
 from unittest.mock import patch
 
@@ -56,11 +56,11 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
     def tearDown(self) -> None:
         self.spark.stop()
 
-    def setup_env_vars(self, input_map: dict[str, str]) -> None:
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
         for key, value in input_map.items():
             os.environ[key] = value
 
-    def delete_env_vars(self, input_map: dict[str, str]) -> None:
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
         for key in input_map.keys():
             del os.environ[key]
 
@@ -196,11 +196,11 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
         os.unlink(self.tempFile.name)
         self.spark.stop()
 
-    def setup_env_vars(self, input_map: dict[str, str]) -> None:
+    def setup_env_vars(self, input_map: Dict[str, str]) -> None:
         for key, value in input_map.items():
             os.environ[key] = value
 
-    def delete_env_vars(self, input_map: dict[str, str]) -> None:
+    def delete_env_vars(self, input_map: Dict[str, str]) -> None:
         for key in input_map.keys():
             del os.environ[key]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to