This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e81c2c895a6 [SPARK-42183][PYTHON][ML][TESTS] Exclude
pyspark.ml.torch.tests in MyPy tests
e81c2c895a6 is described below
commit e81c2c895a66f00888198f615a1f11b410bb7475
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Wed Jan 25 22:42:54 2023 +0900
[SPARK-42183][PYTHON][ML][TESTS] Exclude pyspark.ml.torch.tests in MyPy
tests
### What changes were proposed in this pull request?
This PR proposes to exclude `pyspark.ml.torch.tests` from MyPy checks.
### Why are the changes needed?
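The initial intention was to annotate types for public APIs only (see also
https://github.com/apache/spark/pull/38991). Test packages such as `pyspark.ml.tests` and `pyspark.mllib.tests` are therefore already excluded from MyPy checks via `ignore_errors = True`. This PR applies the same treatment to `pyspark.ml.torch.tests` and removes the `# type: ignore` comments that are no longer needed there.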
### Does this PR introduce _any_ user-facing change?
No, test-only.
### How was this patch tested?
CI in this PR should verify it.
Closes #39740 from HyukjinKwon/SPARK-42183.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/mypy.ini | 3 +++
python/pyspark/ml/torch/tests/test_distributor.py | 20 ++++++++++----------
.../pyspark/ml/torch/tests/test_log_communication.py | 18 +++++++-----------
3 files changed, 20 insertions(+), 21 deletions(-)
diff --git a/python/mypy.ini b/python/mypy.ini
index 5f662a4a237..a845cd88bd8 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -85,6 +85,9 @@ disallow_untyped_defs = False
[mypy-pyspark.ml.tests.*]
ignore_errors = True
+[mypy-pyspark.ml.torch.tests.*]
+ignore_errors = True
+
[mypy-pyspark.mllib.tests.*]
ignore_errors = True
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py
b/python/pyspark/ml/torch/tests/test_distributor.py
index 0f4a4a23dc0..baf68757f67 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -18,7 +18,7 @@
import contextlib
import os
import shutil
-from six import StringIO # type: ignore
+from six import StringIO
import stat
import subprocess
import sys
@@ -57,7 +57,7 @@ def patch_stdout() -> StringIO:
def create_training_function(mnist_dir_path: str) -> Callable:
import torch.nn as nn
import torch.nn.functional as F
- from torchvision import transforms, datasets # type: ignore
+ from torchvision import transforms, datasets
batch_size = 100
num_epochs = 1
@@ -99,7 +99,7 @@ def create_training_function(mnist_dir_path: str) -> Callable:
dist.init_process_group("gloo")
- train_sampler = DistributedSampler(dataset=train_dataset) # type:
ignore
+ train_sampler = DistributedSampler(dataset=train_dataset)
data_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, sampler=train_sampler
)
@@ -220,11 +220,11 @@ class
TorchDistributorBaselineUnitTests(unittest.TestCase):
)
# include command in the exception message
- with self.assertRaisesRegexp(RuntimeError, "exit 1"): # pylint:
disable=deprecated-method
+ with self.assertRaisesRegex(RuntimeError, "exit 1"):
error_command = ["bash", "-c", "exit 1"]
TorchDistributor._execute_command(error_command)
- with self.assertRaisesRegexp(RuntimeError, "abcdef"): # pylint:
disable=deprecated-method
+ with self.assertRaisesRegex(RuntimeError, "abcdef"):
error_command = ["bash", "-c", "'abc''def'"]
TorchDistributor._execute_command(error_command)
@@ -359,7 +359,7 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
dist = TorchDistributor(num_processes, True, use_gpu)
- dist._run_training_on_pytorch_file = lambda *args:
os.environ.get( # type: ignore
+ dist._run_training_on_pytorch_file = lambda *args:
os.environ.get(
CUDA_VISIBLE_DEVICES, "NONE"
)
self.assertEqual(
@@ -429,7 +429,7 @@ class
TorchDistributorDistributedUnitTests(unittest.TestCase):
for i, (_, num_processes, use_gpu, expected) in enumerate(inputs):
with self.subTest(f"subtest: {i + 1}"):
dist = TorchDistributor(num_processes, False, use_gpu)
- dist._run_training_on_pytorch_file = lambda *args:
os.environ.get( # type: ignore
+ dist._run_training_on_pytorch_file = lambda *args:
os.environ.get(
CUDA_VISIBLE_DEVICES, "NONE"
)
self.assertEqual(
@@ -486,14 +486,14 @@ class TorchWrapperUnitTests(unittest.TestCase):
t = threading.Thread(target=check_parent_alive, args=(task,),
daemon=True)
t.start()
time.sleep(2)
- self.assertEqual(mock_clean_and_terminate.call_count, 0) # type:
ignore[attr-defined]
+ self.assertEqual(mock_clean_and_terminate.call_count, 0)
if __name__ == "__main__":
- from pyspark.ml.torch.tests.test_distributor import * # noqa: F401,F403
type: ignore
+ from pyspark.ml.torch.tests.test_distributor import * # noqa: F401,F403
try:
- import xmlrunner # type: ignore
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
except ImportError:
diff --git a/python/pyspark/ml/torch/tests/test_log_communication.py
b/python/pyspark/ml/torch/tests/test_log_communication.py
index 0c937926480..164c7556d12 100644
--- a/python/pyspark/ml/torch/tests/test_log_communication.py
+++ b/python/pyspark/ml/torch/tests/test_log_communication.py
@@ -18,14 +18,14 @@
from __future__ import absolute_import, division, print_function
import contextlib
-from six import StringIO # type: ignore
+from six import StringIO
import sys
import time
from typing import Any, Callable
import unittest
import pyspark.ml.torch.log_communication
-from pyspark.ml.torch.log_communication import ( # type: ignore
+from pyspark.ml.torch.log_communication import (
LogStreamingServer,
LogStreamingClient,
LogStreamingClientBase,
@@ -47,15 +47,11 @@ def patch_stderr() -> StringIO:
class LogStreamingServiceTestCase(unittest.TestCase):
def setUp(self) -> None:
- self.default_truncate_msg_len = (
- pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN # type:
ignore
- )
- pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10 # type:
ignore
+ self.default_truncate_msg_len =
pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN
+ pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = 10
def tearDown(self) -> None:
- pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN = ( # type:
ignore
- self.default_truncate_msg_len
- )
+ pyspark.ml.torch.log_communication._TRUNCATE_MSG_LEN =
self.default_truncate_msg_len
def basic_test(self) -> None:
server = LogStreamingServer()
@@ -165,10 +161,10 @@ class LogStreamingServiceTestCase(unittest.TestCase):
if __name__ == "__main__":
- from pyspark.ml.torch.tests.test_log_communication import * # noqa:
F401,F403 type: ignore
+ from pyspark.ml.torch.tests.test_log_communication import * # noqa:
F401,F403
try:
- import xmlrunner # type: ignore
+ import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
except ImportError:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]