This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ad013d3eaea [SPARK-42993][ML][CONNECT] Make PyTorch Distributor
compatible with Spark Connect
ad013d3eaea is described below
commit ad013d3eaea85800c6428a89918c94994361e1cd
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Apr 7 08:15:52 2023 +0800
[SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark
Connect
### What changes were proposed in this pull request?
Make Torch Distributor support Spark Connect
### Why are the changes needed?
Functionality parity with the classic (non-Connect) PySpark API.
**Note**, `local_mode` with `use_gpu` is not supported for now since
`sc.resources` is missing in Connect
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Reused the existing unit tests (run as parity tests against Spark Connect).
Closes #40607 from zhengruifeng/connect_torch.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../service/SparkConnectStreamHandler.scala | 19 +--
dev/sparktestsupport/modules.py | 1 +
.../tests/connect/test_parity_torch_distributor.py | 134 ++++++++++++++++++++
python/pyspark/ml/torch/distributor.py | 137 ++++++++++++---------
python/pyspark/ml/torch/log_communication.py | 8 +-
python/pyspark/ml/torch/tests/test_distributor.py | 114 +++++++++--------
6 files changed, 293 insertions(+), 120 deletions(-)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 96e16623222..8a14357640e 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver:
StreamObserver[ExecutePlanResp
session.withActive {
// Add debug information to the query execution so that the jobs are
traceable.
- val debugString = v.toString
- session.sparkContext.setLocalProperty(
- "callSite.short",
- s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
- session.sparkContext.setLocalProperty(
- "callSite.long",
- StringUtils.abbreviate(debugString, 2048))
+ try {
+ val debugString = v.toString
+ session.sparkContext.setLocalProperty(
+ "callSite.short",
+ s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
+ session.sparkContext.setLocalProperty(
+ "callSite.long",
+ StringUtils.abbreviate(debugString, 2048))
+ } catch {
+ case e: Throwable =>
+ logWarning("Fail to extract or attach the debug information", e)
+ }
v.getPlan.getOpTypeCase match {
case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1a28a644e55..249e2675b76 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -777,6 +777,7 @@ pyspark_connect = Module(
"pyspark.ml.connect.functions",
# ml unittests
"pyspark.ml.tests.connect.test_connect_function",
+ "pyspark.ml.tests.connect.test_parity_torch_distributor",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
new file mode 100644
index 00000000000..8aa1079beaf
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -0,0 +1,134 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import unittest
+
+have_torch = True
+try:
+ import torch # noqa: F401
+except ImportError:
+ have_torch = False
+
+from pyspark.sql import SparkSession
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+ TorchDistributorBaselineUnitTestsMixin,
+ TorchDistributorLocalUnitTestsMixin,
+ TorchDistributorDistributedUnitTestsMixin,
+ TorchWrapperUnitTestsMixin,
+)
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+ TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+ def setUp(self) -> None:
+ self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+ def test_get_num_tasks_fails(self) -> None:
+ inputs = [1, 5, 4]
+
+ # This is when the conf isn't set and we request GPUs
+ for num_processes in inputs:
+ with self.subTest():
+ # TODO(SPARK-42994): Support sc.resources
+ # with self.assertRaisesRegex(RuntimeError, "driver"):
+ # TorchDistributor(num_processes, True, True)
+ with self.assertRaisesRegex(RuntimeError, "unset"):
+ TorchDistributor(num_processes, False, True)
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+ TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+ def setUp(self) -> None:
+ class_name = self.__class__.__name__
+ conf = self._get_spark_conf()
+ builder = SparkSession.builder.appName(class_name)
+ for k, v in conf.getAll():
+ if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+ builder = builder.config(k, v)
+ self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+ self.mnist_dir_path = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.mnist_dir_path)
+ os.unlink(self.gpu_discovery_script_file.name)
+ self.spark.stop()
+
+ # TODO(SPARK-42994): Support sc.resources
+ @unittest.skip("need to support sc.resources")
+ def test_get_num_tasks_locally(self):
+ super().test_get_num_tasks_locally()
+
+ # TODO(SPARK-42994): Support sc.resources
+ @unittest.skip("need to support sc.resources")
+ def test_get_gpus_owned_local(self):
+ super().test_get_gpus_owned_local()
+
+ # TODO(SPARK-42994): Support sc.resources
+ @unittest.skip("need to support sc.resources")
+ def test_local_training_succeeds(self):
+ super().test_local_training_succeeds()
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+ TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+ def setUp(self) -> None:
+ class_name = self.__class__.__name__
+ conf = self._get_spark_conf()
+ builder = SparkSession.builder.appName(class_name)
+ for k, v in conf.getAll():
+ if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+ builder = builder.config(k, v)
+
+ self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+ self.mnist_dir_path = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.mnist_dir_path)
+ os.unlink(self.gpu_discovery_script_file.name)
+ self.spark.stop()
+
+
[email protected](not have_torch, "torch is required")
+class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin,
unittest.TestCase):
+ pass
+
+
+if __name__ == "__main__":
+ from pyspark.ml.tests.connect.test_parity_torch_distributor import * #
noqa: F401,F403
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports",
verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/torch/distributor.py
b/python/pyspark/ml/torch/distributor.py
index 157cc96717f..7491803aed2 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -32,28 +32,36 @@ from typing import Union, Callable, List, Dict, Optional,
Any, Tuple, Generator
from pyspark import cloudpickle
from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
from pyspark.ml.torch.log_communication import ( # type: ignore
- get_driver_host,
LogStreamingClient,
LogStreamingServer,
)
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
-# TODO(SPARK-41589): will move the functions and tests to an external file
-# once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
- """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+ from pyspark.sql.utils import is_remote
+
+ if not is_remote():
+ spark = SparkSession.getActiveSession()
+ else:
+ import pyspark.sql.connect.session
+
+ spark = pyspark.sql.connect.session._active_spark_session # type:
ignore[assignment]
+
+ if spark is None:
+ raise RuntimeError("An active SparkSession is required for the
distributor.")
+ return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+ """Get the conf "key" from the given spark session,
or return the default value if the conf is not set.
- This expects the conf value to be a boolean or string;
- if the value is a string, this checks for all capitalization
- patterns of "true" and "false" to match Scala.
Parameters
----------
- sc : :class:`SparkContext`
- The :class:`SparkContext` for the distributor.
+ spark : :class:`SparkSession`
+ The :class:`SparkSession` for the distributor.
key : str
string for conf name
default_value : str
@@ -61,25 +69,21 @@ def get_conf_boolean(sc: SparkContext, key: str,
default_value: str) -> bool:
Returns
-------
- bool
- Returns the boolean value that corresponds to the conf
-
- Raises
- ------
- ValueError
- Thrown when the conf value is not a valid boolean
+ str
+ Returns the string value that corresponds to the conf
"""
- val = sc.getConf().get(key, default_value)
- lowercase_val = val.lower()
- if lowercase_val == "true":
- return True
- if lowercase_val == "false":
- return False
- raise ValueError(
- f"The conf value for '{key}' was expected to be a boolean "
- f"value but found value of type {type(val)} "
- f"with value: {val}"
- )
+ value = spark.conf.get(key, default_value)
+ assert value is not None
+ return value
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+# once we are in agreement about which functions should be in utils.py
+def _get_conf_boolean(spark: SparkSession, key: str, default_value: str) ->
bool:
+ value = _get_conf(spark=spark, key=key, default_value=default_value)
+ value = value.lower()
+ assert value in ["true", "false"]
+ return value == "true"
def get_logger(name: str) -> logging.Logger:
@@ -95,13 +99,13 @@ def get_logger(name: str) -> logging.Logger:
return logger
-def get_gpus_owned(context: Union[SparkContext, BarrierTaskContext]) ->
List[str]:
+def get_gpus_owned(context: Union[SparkSession, BarrierTaskContext]) ->
List[str]:
"""Gets the number of GPUs that Spark scheduled to the calling task.
Parameters
----------
- context : :class:`SparkContext` or :class:`BarrierTaskContext`
- The :class:`SparkContext` or :class:`BarrierTaskContext` that has GPUs
available.
+ context : :class:`SparkSession` or :class:`BarrierTaskContext`
+ The :class:`SparkSession` or :class:`BarrierTaskContext` that has GPUs
available.
Returns
-------
@@ -115,8 +119,9 @@ def get_gpus_owned(context: Union[SparkContext,
BarrierTaskContext]) -> List[str
"""
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
pattern = re.compile("^[1-9][0-9]*|0$")
- if isinstance(context, SparkContext):
- addresses = context.resources["gpu"].addresses
+ if isinstance(context, SparkSession):
+ # TODO(SPARK-42994): Support sc.resources in Spark Connect
+ addresses = context.sparkContext.resources["gpu"].addresses
else:
addresses = context.resources()["gpu"].addresses
if any(not pattern.match(address) for address in addresses):
@@ -145,20 +150,17 @@ class Distributor:
local_mode: bool = True,
use_gpu: bool = True,
):
+ self.spark = _get_active_session()
self.logger = get_logger(self.__class__.__name__)
self.num_processes = num_processes
self.local_mode = local_mode
self.use_gpu = use_gpu
- self.spark = SparkSession.getActiveSession()
- if not self.spark:
- raise RuntimeError("An active SparkSession is required for the
distributor.")
- self.sc = self.spark.sparkContext
self.num_tasks = self._get_num_tasks()
self.ssl_conf = None
def _create_input_params(self) -> Dict[str, Any]:
input_params = self.__dict__.copy()
- for unneeded_param in ["spark", "sc", "ssl_conf", "logger"]:
+ for unneeded_param in ["spark", "ssl_conf", "logger"]:
del input_params[unneeded_param]
return input_params
@@ -176,20 +178,20 @@ class Distributor:
RuntimeError
Raised when the SparkConf was misconfigured.
"""
-
if self.use_gpu:
if not self.local_mode:
key = "spark.task.resource.gpu.amount"
- task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+ task_gpu_amount = int(_get_conf(self.spark, key, "0"))
if task_gpu_amount < 1:
raise RuntimeError(f"'{key}' was unset, so gpu usage is
unavailable.")
# TODO(SPARK-41916): Address situation when
spark.task.resource.gpu.amount > 1
return math.ceil(self.num_processes / task_gpu_amount)
else:
key = "spark.driver.resource.gpu.amount"
- if "gpu" not in self.sc.resources:
+ # TODO(SPARK-42994): Support sc.resources in Spark Connect
+ if "gpu" not in self.spark.sparkContext.resources:
raise RuntimeError("GPUs were unable to be found on the
driver.")
- num_available_gpus = int(self.sc.getConf().get(key, "0"))
+ num_available_gpus = int(_get_conf(self.spark, key, "0"))
if num_available_gpus == 0:
raise RuntimeError("GPU resources were not configured
properly on the driver.")
if self.num_processes > num_available_gpus:
@@ -215,12 +217,12 @@ class Distributor:
Thrown when the user requires ssl encryption or when the user
initializes
the Distributor parent class.
"""
- if not "ssl_conf":
+ if not hasattr(self, "ssl_conf"):
raise RuntimeError(
"Distributor doesn't have this functionality. Use
TorchDistributor instead."
)
- is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled",
"false")
- ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false") #
type: ignore
+ is_ssl_enabled = _get_conf_boolean(self.spark, "spark.ssl.enabled",
"false")
+ ignore_ssl = _get_conf_boolean(self.spark, self.ssl_conf, "false") #
type: ignore
if is_ssl_enabled:
name = self.__class__.__name__
if ignore_ssl:
@@ -263,6 +265,10 @@ class TorchDistributor(Distributor):
.. versionadded:: 3.4.0
+ .. versionchanged:: 3.5.0
+ Supports Spark Connect. Note that local mode with GPU is not supported
yet, will be fixed
+ in SPARK-42994.
+
Parameters
----------
num_processes : int, optional
@@ -451,7 +457,7 @@ class TorchDistributor(Distributor):
old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
try:
if self.use_gpu:
- gpus_owned = get_gpus_owned(self.sc)
+ gpus_owned = get_gpus_owned(self.spark)
random.seed(hash(train_object))
selected_gpus = [str(e) for e in random.sample(gpus_owned,
self.num_processes)]
os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
@@ -499,6 +505,7 @@ class TorchDistributor(Distributor):
# Spark task program
def wrapped_train_fn(_): # type: ignore[no-untyped-def]
import os
+ import pandas as pd
from pyspark import BarrierTaskContext
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -554,7 +561,21 @@ class TorchDistributor(Distributor):
pass
if context.partitionId() == 0:
- yield output
+ output_bytes = cloudpickle.dumps(output)
+ output_size = len(output_bytes)
+
+ # In Spark Connect, DataFrame.collect stacks rows to size
+ # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+ # here use 4KiB for each chunk, which mean each arrow batch
+ # may contain about 1000 chunks.
+ chunks = []
+ chunk_size = 4096
+ index = 0
+ while index < output_size:
+ chunks.append(output_bytes[index : index + chunk_size])
+ index += chunk_size
+
+ yield pd.DataFrame(data={"chunk": chunks})
return wrapped_train_fn
@@ -568,7 +589,8 @@ class TorchDistributor(Distributor):
raise RuntimeError("Unknown combination of parameters")
log_streaming_server = LogStreamingServer()
- self.driver_address = get_driver_host(self.sc)
+ self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
+ assert self.driver_address != ""
log_streaming_server.start(spark_host_address=self.driver_address)
time.sleep(1) # wait for the server to start
self.log_streaming_server_port = log_streaming_server.port
@@ -578,19 +600,20 @@ class TorchDistributor(Distributor):
)
self._check_encryption()
self.logger.info(
- f"Started distributed training with {self.num_processes} executor
proceses"
+ f"Started distributed training with {self.num_processes} executor
processes"
)
try:
- result = (
- self.sc.parallelize(range(self.num_tasks), self.num_tasks)
- .barrier()
- .mapPartitions(spark_task_function)
- .collect()[0]
+ rows = (
+ self.spark.range(start=0, end=self.num_tasks, step=1,
numPartitions=self.num_tasks)
+ .mapInPandas(func=spark_task_function, schema="chunk binary",
barrier=True)
+ .collect()
)
+ output_bytes = b"".join([row.chunk for row in rows])
+ result = cloudpickle.loads(output_bytes)
finally:
log_streaming_server.shutdown()
self.logger.info(
- f"Finished distributed training with {self.num_processes} executor
proceses"
+ f"Finished distributed training with {self.num_processes} executor
processes"
)
return result
diff --git a/python/pyspark/ml/torch/log_communication.py
b/python/pyspark/ml/torch/log_communication.py
index e269aed42d1..ca91121d3e3 100644
--- a/python/pyspark/ml/torch/log_communication.py
+++ b/python/pyspark/ml/torch/log_communication.py
@@ -24,19 +24,13 @@ from struct import pack, unpack
import sys
import threading
import traceback
-from typing import Optional, Generator
+from typing import Generator
import warnings
-from pyspark.context import SparkContext
# Use b'\x00' as separator instead of b'\n', because the bytes are encoded in
utf-8
_SERVER_POLL_INTERVAL = 0.1
_TRUNCATE_MSG_LEN = 4000
-
-def get_driver_host(sc: SparkContext) -> Optional[str]:
- return sc.getConf().get("spark.driver.host")
-
-
_log_print_lock = threading.Lock() # pylint: disable=invalid-name
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py
b/python/pyspark/ml/torch/tests/test_distributor.py
index baf68757f67..56569770157 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -117,21 +117,12 @@ def create_training_function(mnist_dir_path: str) ->
Callable:
optimizer.step()
print(f"epoch {epoch} finished.")
- return "success"
+ return "success" * 4096
return train_fn
[email protected](not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTests(unittest.TestCase):
- def setUp(self) -> None:
- conf = SparkConf()
- self.sc = SparkContext("local[4]", conf=conf)
- self.spark = SparkSession(self.sc)
-
- def tearDown(self) -> None:
- self.spark.stop()
-
+class TorchDistributorBaselineUnitTestsMixin:
def setup_env_vars(self, input_map: Dict[str, str]) -> None:
for key, value in input_map.items():
os.environ[key] = value
@@ -175,8 +166,8 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
]
for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value
in inputs:
with self.subTest():
- self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
- self.spark.sparkContext._conf.set(pytorch_conf_key,
pytorch_conf_value)
+ self.spark.conf.set(ssl_conf_key, ssl_conf_value)
+ self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
distributor = TorchDistributor(1, True, False)
distributor._check_encryption()
@@ -186,8 +177,8 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value
in inputs:
with self.subTest():
with self.assertRaisesRegex(Exception, "encryption"):
- self.spark.sparkContext._conf.set(ssl_conf_key,
ssl_conf_value)
- self.spark.sparkContext._conf.set(pytorch_conf_key,
pytorch_conf_value)
+ self.spark.conf.set(ssl_conf_key, ssl_conf_value)
+ self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
distributor = TorchDistributor(1, True, False)
distributor._check_encryption()
@@ -279,9 +270,18 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTests(unittest.TestCase):
+class
TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin,
unittest.TestCase):
def setUp(self) -> None:
- class_name = self.__class__.__name__
+ conf = SparkConf()
+ sc = SparkContext("local[4]", conf=conf)
+ self.spark = SparkSession(sc)
+
+ def tearDown(self) -> None:
+ self.spark.stop()
+
+
+class TorchDistributorLocalUnitTestsMixin:
+ def _get_spark_conf(self) -> SparkConf:
self.gpu_discovery_script_file =
tempfile.NamedTemporaryFile(delete=False)
self.gpu_discovery_script_file.write(
b'echo {\\"name\\": \\"gpu\\", \\"addresses\\":
[\\"0\\",\\"1\\",\\"2\\"]}'
@@ -294,21 +294,13 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
self.gpu_discovery_script_file.name,
stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH |
stat.S_IXOTH,
)
- conf = SparkConf().set("spark.test.home", SPARK_HOME)
+ conf = SparkConf().set("spark.test.home", SPARK_HOME)
conf = conf.set("spark.driver.resource.gpu.amount", "3")
conf = conf.set(
"spark.driver.resource.gpu.discoveryScript",
self.gpu_discovery_script_file.name
)
-
- self.sc = SparkContext("local-cluster[2,2,1024]", class_name,
conf=conf)
- self.spark = SparkSession(self.sc)
- self.mnist_dir_path = tempfile.mkdtemp()
-
- def tearDown(self) -> None:
- shutil.rmtree(self.mnist_dir_path)
- os.unlink(self.gpu_discovery_script_file.name)
- self.spark.stop()
+ return conf
def setup_env_vars(self, input_map: Dict[str, str]) -> None:
for key, value in input_map.items():
@@ -336,11 +328,11 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
def test_get_gpus_owned_local(self) -> None:
addresses = ["0", "1", "2"]
- self.assertEqual(get_gpus_owned(self.sc), addresses)
+ self.assertEqual(get_gpus_owned(self.spark), addresses)
env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"}
self.setup_env_vars(env_vars)
- self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+ self.assertEqual(get_gpus_owned(self.spark), ["3", "4", "5"])
self.delete_env_vars(env_vars)
def test_local_training_succeeds(self) -> None:
@@ -382,27 +374,40 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
output = TorchDistributor(num_processes=2, local_mode=True,
use_gpu=False).run(
train_fn, 0.001
)
- self.assertEqual(output, "success")
+ self.assertEqual(output, "success" * 4096)
@unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorDistributedUnitTests(unittest.TestCase):
+class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin,
unittest.TestCase):
def setUp(self) -> None:
class_name = self.__class__.__name__
+ conf = self._get_spark_conf()
+ sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+ self.spark = SparkSession(sc)
+ self.mnist_dir_path = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.mnist_dir_path)
+ os.unlink(self.gpu_discovery_script_file.name)
+ self.spark.stop()
+
+
+class TorchDistributorDistributedUnitTestsMixin:
+ def _get_spark_conf(self) -> SparkConf:
self.gpu_discovery_script_file =
tempfile.NamedTemporaryFile(delete=False)
self.gpu_discovery_script_file.write(
b'echo {\\"name\\": \\"gpu\\", \\"addresses\\":
[\\"0\\",\\"1\\",\\"2\\"]}'
)
self.gpu_discovery_script_file.close()
# create temporary directory for Worker resources coordination
- self.tempdir = tempfile.NamedTemporaryFile(delete=False)
- os.unlink(self.tempdir.name)
+ tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(tempdir.name)
os.chmod(
self.gpu_discovery_script_file.name,
stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH |
stat.S_IXOTH,
)
- conf = SparkConf().set("spark.test.home", SPARK_HOME)
+ conf = SparkConf().set("spark.test.home", SPARK_HOME)
conf = conf.set(
"spark.worker.resource.gpu.discoveryScript",
self.gpu_discovery_script_file.name
)
@@ -410,15 +415,7 @@ class
TorchDistributorDistributedUnitTests(unittest.TestCase):
conf = conf.set("spark.task.cpus", "2")
conf = conf.set("spark.task.resource.gpu.amount", "1")
conf = conf.set("spark.executor.resource.gpu.amount", "1")
-
- self.sc = SparkContext("local-cluster[2,2,1024]", class_name,
conf=conf)
- self.spark = SparkSession(self.sc)
- self.mnist_dir_path = tempfile.mkdtemp()
-
- def tearDown(self) -> None:
- shutil.rmtree(self.mnist_dir_path)
- os.unlink(self.gpu_discovery_script_file.name)
- self.spark.stop()
+ return conf
def test_dist_training_succeeds(self) -> None:
CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -442,13 +439,11 @@ class
TorchDistributorDistributedUnitTests(unittest.TestCase):
for spark_conf_value, num_processes, expected_output in inputs:
with self.subTest():
- self.spark.sparkContext._conf.set(
- "spark.task.resource.gpu.amount", str(spark_conf_value)
- )
+ self.spark.conf.set("spark.task.resource.gpu.amount",
str(spark_conf_value))
distributor = TorchDistributor(num_processes, False, True)
self.assertEqual(distributor._get_num_tasks(), expected_output)
- self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount",
"1")
+ self.spark.conf.set("spark.task.resource.gpu.amount", "1")
def test_distributed_file_with_pytorch(self) -> None:
test_file_path = "python/test_support/test_pytorch_training_file.py"
@@ -462,11 +457,27 @@ class
TorchDistributorDistributedUnitTests(unittest.TestCase):
output = TorchDistributor(num_processes=2, local_mode=False,
use_gpu=False).run(
train_fn, 0.001
)
- self.assertEqual(output, "success")
+ self.assertEqual(output, "success" * 4096)
@unittest.skipIf(not have_torch, "torch is required")
-class TorchWrapperUnitTests(unittest.TestCase):
+class TorchDistributorDistributedUnitTests(
+ TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+ def setUp(self) -> None:
+ class_name = self.__class__.__name__
+ conf = self._get_spark_conf()
+ sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+ self.spark = SparkSession(sc)
+ self.mnist_dir_path = tempfile.mkdtemp()
+
+ def tearDown(self) -> None:
+ shutil.rmtree(self.mnist_dir_path)
+ os.unlink(self.gpu_discovery_script_file.name)
+ self.spark.stop()
+
+
+class TorchWrapperUnitTestsMixin:
def test_clean_and_terminate(self) -> None:
def kill_task(task: "subprocess.Popen") -> None:
time.sleep(1)
@@ -489,6 +500,11 @@ class TorchWrapperUnitTests(unittest.TestCase):
self.assertEqual(mock_clean_and_terminate.call_count, 0)
[email protected](not have_torch, "torch is required")
+class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
+ pass
+
+
if __name__ == "__main__":
from pyspark.ml.torch.tests.test_distributor import * # noqa: F401,F403
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]