This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e81c2c895a6 [SPARK-42183][PYTHON][ML][TESTS] Exclude pyspark.ml.torch.tests in MyPy tests
e81c2c895a6 is described below

commit e81c2c895a66f00888198f615a1f11b410bb7475
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Wed Jan 25 22:42:54 2023 +0900

    [SPARK-42183][PYTHON][ML][TESTS] Exclude pyspark.ml.torch.tests in MyPy tests
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to exclude `pyspark.ml.torch.tests` from the MyPy checks.
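
    For reference, the diff below adds the following per-module override to `python/mypy.ini`, following the same pattern as the existing `pyspark.ml.tests.*` and `pyspark.mllib.tests.*` overrides:

    ```ini
    [mypy-pyspark.ml.torch.tests.*]
    ignore_errors = True
    ```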
    
    ### Why are the changes needed?
    
    The initial intention was to annotate types for public APIs only (see also https://github.com/apache/spark/pull/38991); the modules under `pyspark.ml.torch.tests` are internal tests, so they should be excluded from the MyPy checks like the other test packages.
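
    Concretely, `ignore_errors = True` means MyPy still processes these files but suppresses any reported errors, so the inline suppressions scattered through the test modules become unnecessary. A representative before/after, taken from the diff below:

    ```python
    # Before: untyped imports in the tests needed inline suppressions
    from six import StringIO  # type: ignore

    # After: errors in pyspark.ml.torch.tests.* are ignored via mypy.ini,
    # so the comment can be dropped
    from six import StringIO
    ```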
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    CI in this PR should test it out.
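
    For a quick local check, here is a minimal sketch using MyPy's Python API (assuming `mypy` is installed and this is run from the `python/` directory; the CI above remains the authoritative test):

    ```python
    from mypy import api

    # Run MyPy with the project config; with the new override in place,
    # files under pyspark/ml/torch/tests/ should report no errors.
    stdout, stderr, exit_code = api.run(
        ["--config-file", "mypy.ini", "pyspark/ml/torch/tests/"]
    )
    print(stdout)
    ```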
    
    Closes #39740 from HyukjinKwon/SPARK-42183.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/mypy.ini                                      |  3 +++
 python/pyspark/ml/torch/tests/test_distributor.py    | 20 ++++++++++----------
 .../pyspark/ml/torch/tests/test_log_communication.py | 18 +++++++-----------
 3 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/python/mypy.ini b/python/mypy.ini
index 5f662a4a237..a845cd88bd8 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -85,6 +85,9 @@ disallow_untyped_defs = False
 [mypy-pyspark.ml.tests.*]
 ignore_errors = True
 
+[mypy-pyspark.ml.torch.tests.*]
+ignore_errors = True
+
 [mypy-pyspark.mllib.tests.*]
 ignore_errors = True
 
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 0f4a4a23dc0..baf68757f67 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -18,7 +18,7 @@
 import contextlib
 import os
 import shutil
-from six import StringIO  # type: ignore
+from six import StringIO
 import stat
 import subprocess
 import sys
@@ -57,7 +57,7 @@ def patch_stdout() -> StringIO:
 def create_training_function(mnist_dir_path: str) -> Callable:
     import torch.nn as nn
     import torch.nn.functional as F
-    from torchvision import transforms, datasets  # type: ignore
+    from torchvision import transforms, datasets
 
     batch_size = 100
     num_epochs = 1
@@ -99,7 +99,7 @@ def create_training_function(mnist_dir_path: str) -> Callable:
 
         dist.init_process_group("gloo")
 
-        train_sampler = DistributedSampler(dataset=train_dataset)  # type: ignore
+        train_sampler = DistributedSampler(dataset=train_dataset)
         data_loader = torch.utils.data.DataLoader(
             train_dataset, batch_size=batch_size, sampler=train_sampler
         )
@@ -220,11 +220,11 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
             )
 
         # include command in the exception message
-        with self.assertRaisesRegexp(RuntimeError, "exit 1"):  # pylint: 
disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "exit 1"):
             error_command = ["bash", "-c", "exit 1"]
             TorchDistributor._execute_command(error_command)
 
-        with self.assertRaisesRegexp(RuntimeError, "abcdef"):  # pylint: 
disable=deprecated-method
+        with self.assertRaisesRegex(RuntimeError, "abcdef"):
             error_command = ["bash", "-c", "'abc''def'"]
             TorchDistributor._execute_command(error_command)
 
@@ -359,7 +359,7 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
                     self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
 
                 dist = TorchDistributor(num_processes, True, use_gpu)
-                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                     CUDA_VISIBLE_DEVICES, "NONE"
                 )
                 self.assertEqual(
@@ -429,7 +429,7 @@ class TorchDistributorDistributedUnitTests(unittest.TestCase):
         for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
             with self.subTest(f"subtest: {i + 1}"):
                 dist = TorchDistributor(num_processes, False, use_gpu)
-                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(
                     CUDA_VISIBLE_DEVICES, "NONE"
                 )
                 self.assertEqual(
@@ -486,14 +486,14 @@ class TorchWrapperUnitTests(unittest.TestCase):
         t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
         t.start()
         time.sleep(2)
-        self.assertEqual(mock_clean_and_terminate.call_count, 0)  # type: ignore[attr-defined]
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py b/python/pyspark/ml/torch/tests/test_log_communication.py
index 0c937926480..164c7556d12 100644
--- a/python/pyspark/ml/torch/tests/test_log_communication.py
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -18,14 +18,14 @@
 from __future__ import absolute_import, division, print_function
 
 import contextlib
-from six import StringIO  # type: ignore
+from six import StringIO
 import sys
 import time
 from typing import Any, Callable
 import unittest
 
 import pyspark.ml.torch.log_communication
-from pyspark.ml.torch.log_communication import (  # type: ignore
+from pyspark.ml.torch.log_communication import (
     LogStreamingServer,
     LogStreamingClient,
     LogStreamingClientBase,
@@ -47,15 +47,11 @@ def patch_stderr() -> StringIO:
 
 class LogStreamingServiceTestCase(unittest.TestCase):
     def setUp(self) -> None:
-        self.default_truncate_msg_len = (
-            pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN  # type: ignore
-        )
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10  # type: ignore
+        self.default_truncate_msg_len = pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10
 
     def tearDown(self) -> None:
-        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = (  # type: ignore
-            self.default_truncate_msg_len
-        )
+        pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = self.default_truncate_msg_len
 
     def basic_test(self) -> None:
         server = LogStreamingServer()
@@ -165,10 +161,10 @@ class LogStreamingServiceTestCase(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403 type: ignore
+    from pyspark.ml.torch.tests.test_log_communication import *  # noqa: F401,F403
 
     try:
-        import xmlrunner  # type: ignore
+        import xmlrunner
 
         testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
     except ImportError:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]