This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ad013d3eaea [SPARK-42993][ML][CONNECT] Make PyTorch Distributor 
compatible with Spark Connect
ad013d3eaea is described below

commit ad013d3eaea85800c6428a89918c94994361e1cd
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Apr 7 08:15:52 2023 +0800

    [SPARK-42993][ML][CONNECT] Make PyTorch Distributor compatible with Spark 
Connect
    
    ### What changes were proposed in this pull request?
    Make Torch Distributor support Spark Connect
    
    ### Why are the changes needed?
    Functionality parity: bring the PyTorch Distributor's capabilities to Spark Connect sessions.
    
    **Note**: `local_mode` with `use_gpu` is not supported for now, since
`sc.resources` is missing in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    ### How was this patch tested?
    Reused the existing unit tests (refactored into shared test mixins so they run against both classic PySpark and Spark Connect).
    
    Closes #40607 from zhengruifeng/connect_torch.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../service/SparkConnectStreamHandler.scala        |  19 +--
 dev/sparktestsupport/modules.py                    |   1 +
 .../tests/connect/test_parity_torch_distributor.py | 134 ++++++++++++++++++++
 python/pyspark/ml/torch/distributor.py             | 137 ++++++++++++---------
 python/pyspark/ml/torch/log_communication.py       |   8 +-
 python/pyspark/ml/torch/tests/test_distributor.py  | 114 +++++++++--------
 6 files changed, 293 insertions(+), 120 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 96e16623222..8a14357640e 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -52,13 +52,18 @@ class SparkConnectStreamHandler(responseObserver: 
StreamObserver[ExecutePlanResp
     session.withActive {
 
       // Add debug information to the query execution so that the jobs are 
traceable.
-      val debugString = v.toString
-      session.sparkContext.setLocalProperty(
-        "callSite.short",
-        s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
-      session.sparkContext.setLocalProperty(
-        "callSite.long",
-        StringUtils.abbreviate(debugString, 2048))
+      try {
+        val debugString = v.toString
+        session.sparkContext.setLocalProperty(
+          "callSite.short",
+          s"Spark Connect - ${StringUtils.abbreviate(debugString, 128)}")
+        session.sparkContext.setLocalProperty(
+          "callSite.long",
+          StringUtils.abbreviate(debugString, 2048))
+      } catch {
+        case e: Throwable =>
+          logWarning("Fail to extract or attach the debug information", e)
+      }
 
       v.getPlan.getOpTypeCase match {
         case proto.Plan.OpTypeCase.COMMAND => handleCommand(session, v)
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1a28a644e55..249e2675b76 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -777,6 +777,7 @@ pyspark_connect = Module(
         "pyspark.ml.connect.functions",
         # ml unittests
         "pyspark.ml.tests.connect.test_connect_function",
+        "pyspark.ml.tests.connect.test_parity_torch_distributor",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py 
b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
new file mode 100644
index 00000000000..8aa1079beaf
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_torch_distributor.py
@@ -0,0 +1,134 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import shutil
+import tempfile
+import unittest
+
+have_torch = True
+try:
+    import torch  # noqa: F401
+except ImportError:
+    have_torch = False
+
+from pyspark.sql import SparkSession
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+from pyspark.ml.torch.tests.test_distributor import (
+    TorchDistributorBaselineUnitTestsMixin,
+    TorchDistributorLocalUnitTestsMixin,
+    TorchDistributorDistributedUnitTestsMixin,
+    TorchWrapperUnitTestsMixin,
+)
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorBaselineUnitTestsOnConnect(
+    TorchDistributorBaselineUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        self.spark = SparkSession.builder.remote("local[4]").getOrCreate()
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+    def test_get_num_tasks_fails(self) -> None:
+        inputs = [1, 5, 4]
+
+        # This is when the conf isn't set and we request GPUs
+        for num_processes in inputs:
+            with self.subTest():
+                # TODO(SPARK-42994): Support sc.resources
+                # with self.assertRaisesRegex(RuntimeError, "driver"):
+                #     TorchDistributor(num_processes, True, True)
+                with self.assertRaisesRegex(RuntimeError, "unset"):
+                    TorchDistributor(num_processes, False, True)
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorLocalUnitTestsOnConnect(
+    TorchDistributorLocalUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_num_tasks_locally(self):
+        super().test_get_num_tasks_locally()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_get_gpus_owned_local(self):
+        super().test_get_gpus_owned_local()
+
+    # TODO(SPARK-42994): Support sc.resources
+    @unittest.skip("need to support sc.resources")
+    def test_local_training_succeeds(self):
+        super().test_local_training_succeeds()
+
+
[email protected](not have_torch, "torch is required")
+class TorchDistributorDistributedUnitTestsOnConnect(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        builder = SparkSession.builder.appName(class_name)
+        for k, v in conf.getAll():
+            if k not in ["spark.master", "spark.remote", "spark.app.name"]:
+                builder = builder.config(k, v)
+
+        self.spark = builder.remote("local-cluster[2,2,1024]").getOrCreate()
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+
[email protected](not have_torch, "torch is required")
+class TorchWrapperUnitTestsOnConnect(TorchWrapperUnitTestsMixin, 
unittest.TestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_torch_distributor import *  # 
noqa: F401,F403
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", 
verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/torch/distributor.py 
b/python/pyspark/ml/torch/distributor.py
index 157cc96717f..7491803aed2 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -32,28 +32,36 @@ from typing import Union, Callable, List, Dict, Optional, 
Any, Tuple, Generator
 
 from pyspark import cloudpickle
 from pyspark.sql import SparkSession
+from pyspark.taskcontext import BarrierTaskContext
 from pyspark.ml.torch.log_communication import (  # type: ignore
-    get_driver_host,
     LogStreamingClient,
     LogStreamingServer,
 )
-from pyspark.context import SparkContext
-from pyspark.taskcontext import BarrierTaskContext
 
 
-# TODO(SPARK-41589): will move the functions and tests to an external file
-#       once we are in agreement about which functions should be in utils.py
-def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
-    """Get the conf "key" from the given spark context,
+def _get_active_session() -> SparkSession:
+    from pyspark.sql.utils import is_remote
+
+    if not is_remote():
+        spark = SparkSession.getActiveSession()
+    else:
+        import pyspark.sql.connect.session
+
+        spark = pyspark.sql.connect.session._active_spark_session  # type: 
ignore[assignment]
+
+    if spark is None:
+        raise RuntimeError("An active SparkSession is required for the 
distributor.")
+    return spark
+
+
+def _get_conf(spark: SparkSession, key: str, default_value: str) -> str:
+    """Get the conf "key" from the given spark session,
     or return the default value if the conf is not set.
-    This expects the conf value to be a boolean or string;
-    if the value is a string, this checks for all capitalization
-    patterns of "true" and "false" to match Scala.
 
     Parameters
     ----------
-    sc : :class:`SparkContext`
-        The :class:`SparkContext` for the distributor.
+    spark : :class:`SparkSession`
+        The :class:`SparkSession` for the distributor.
     key : str
         string for conf name
     default_value : str
@@ -61,25 +69,21 @@ def get_conf_boolean(sc: SparkContext, key: str, 
default_value: str) -> bool:
 
     Returns
     -------
-    bool
-        Returns the boolean value that corresponds to the conf
-
-    Raises
-    ------
-    ValueError
-        Thrown when the conf value is not a valid boolean
+    str
+        Returns the string value that corresponds to the conf
     """
-    val = sc.getConf().get(key, default_value)
-    lowercase_val = val.lower()
-    if lowercase_val == "true":
-        return True
-    if lowercase_val == "false":
-        return False
-    raise ValueError(
-        f"The conf value for '{key}' was expected to be a boolean "
-        f"value but found value of type {type(val)} "
-        f"with value: {val}"
-    )
+    value = spark.conf.get(key, default_value)
+    assert value is not None
+    return value
+
+
+# TODO(SPARK-41589): will move the functions and tests to an external file
+#       once we are in agreement about which functions should be in utils.py
+def _get_conf_boolean(spark: SparkSession, key: str, default_value: str) -> 
bool:
+    value = _get_conf(spark=spark, key=key, default_value=default_value)
+    value = value.lower()
+    assert value in ["true", "false"]
+    return value == "true"
 
 
 def get_logger(name: str) -> logging.Logger:
@@ -95,13 +99,13 @@ def get_logger(name: str) -> logging.Logger:
     return logger
 
 
-def get_gpus_owned(context: Union[SparkContext, BarrierTaskContext]) -> 
List[str]:
+def get_gpus_owned(context: Union[SparkSession, BarrierTaskContext]) -> 
List[str]:
     """Gets the number of GPUs that Spark scheduled to the calling task.
 
     Parameters
     ----------
-    context : :class:`SparkContext` or :class:`BarrierTaskContext`
-        The :class:`SparkContext` or :class:`BarrierTaskContext` that has GPUs 
available.
+    context : :class:`SparkSession` or :class:`BarrierTaskContext`
+        The :class:`SparkSession` or :class:`BarrierTaskContext` that has GPUs 
available.
 
     Returns
     -------
@@ -115,8 +119,9 @@ def get_gpus_owned(context: Union[SparkContext, 
BarrierTaskContext]) -> List[str
     """
     CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
     pattern = re.compile("^[1-9][0-9]*|0$")
-    if isinstance(context, SparkContext):
-        addresses = context.resources["gpu"].addresses
+    if isinstance(context, SparkSession):
+        # TODO(SPARK-42994): Support sc.resources in Spark Connect
+        addresses = context.sparkContext.resources["gpu"].addresses
     else:
         addresses = context.resources()["gpu"].addresses
     if any(not pattern.match(address) for address in addresses):
@@ -145,20 +150,17 @@ class Distributor:
         local_mode: bool = True,
         use_gpu: bool = True,
     ):
+        self.spark = _get_active_session()
         self.logger = get_logger(self.__class__.__name__)
         self.num_processes = num_processes
         self.local_mode = local_mode
         self.use_gpu = use_gpu
-        self.spark = SparkSession.getActiveSession()
-        if not self.spark:
-            raise RuntimeError("An active SparkSession is required for the 
distributor.")
-        self.sc = self.spark.sparkContext
         self.num_tasks = self._get_num_tasks()
         self.ssl_conf = None
 
     def _create_input_params(self) -> Dict[str, Any]:
         input_params = self.__dict__.copy()
-        for unneeded_param in ["spark", "sc", "ssl_conf", "logger"]:
+        for unneeded_param in ["spark", "ssl_conf", "logger"]:
             del input_params[unneeded_param]
         return input_params
 
@@ -176,20 +178,20 @@ class Distributor:
         RuntimeError
             Raised when the SparkConf was misconfigured.
         """
-
         if self.use_gpu:
             if not self.local_mode:
                 key = "spark.task.resource.gpu.amount"
-                task_gpu_amount = int(self.sc.getConf().get(key, "0"))
+                task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is 
unavailable.")
                 # TODO(SPARK-41916): Address situation when 
spark.task.resource.gpu.amount > 1
                 return math.ceil(self.num_processes / task_gpu_amount)
             else:
                 key = "spark.driver.resource.gpu.amount"
-                if "gpu" not in self.sc.resources:
+                # TODO(SPARK-42994): Support sc.resources in Spark Connect
+                if "gpu" not in self.spark.sparkContext.resources:
                     raise RuntimeError("GPUs were unable to be found on the 
driver.")
-                num_available_gpus = int(self.sc.getConf().get(key, "0"))
+                num_available_gpus = int(_get_conf(self.spark, key, "0"))
                 if num_available_gpus == 0:
                     raise RuntimeError("GPU resources were not configured 
properly on the driver.")
                 if self.num_processes > num_available_gpus:
@@ -215,12 +217,12 @@ class Distributor:
             Thrown when the user requires ssl encryption or when the user 
initializes
             the Distributor parent class.
         """
-        if not "ssl_conf":
+        if not hasattr(self, "ssl_conf"):
             raise RuntimeError(
                 "Distributor doesn't have this functionality. Use 
TorchDistributor instead."
             )
-        is_ssl_enabled = get_conf_boolean(self.sc, "spark.ssl.enabled", 
"false")
-        ignore_ssl = get_conf_boolean(self.sc, self.ssl_conf, "false")  # 
type: ignore
+        is_ssl_enabled = _get_conf_boolean(self.spark, "spark.ssl.enabled", 
"false")
+        ignore_ssl = _get_conf_boolean(self.spark, self.ssl_conf, "false")  # 
type: ignore
         if is_ssl_enabled:
             name = self.__class__.__name__
             if ignore_ssl:
@@ -263,6 +265,10 @@ class TorchDistributor(Distributor):
 
     .. versionadded:: 3.4.0
 
+    .. versionchanged:: 3.5.0
+        Supports Spark Connect. Note that local mode with GPU is not supported 
yet, will be fixed
+        in SPARK-42994.
+
     Parameters
     ----------
     num_processes : int, optional
@@ -451,7 +457,7 @@ class TorchDistributor(Distributor):
         old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
         try:
             if self.use_gpu:
-                gpus_owned = get_gpus_owned(self.sc)
+                gpus_owned = get_gpus_owned(self.spark)
                 random.seed(hash(train_object))
                 selected_gpus = [str(e) for e in random.sample(gpus_owned, 
self.num_processes)]
                 os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
@@ -499,6 +505,7 @@ class TorchDistributor(Distributor):
         # Spark task program
         def wrapped_train_fn(_):  # type: ignore[no-untyped-def]
             import os
+            import pandas as pd
             from pyspark import BarrierTaskContext
 
             CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -554,7 +561,21 @@ class TorchDistributor(Distributor):
                     pass
 
             if context.partitionId() == 0:
-                yield output
+                output_bytes = cloudpickle.dumps(output)
+                output_size = len(output_bytes)
+
+                # In Spark Connect, DataFrame.collect stacks rows to size
+                # 'spark.connect.grpc.arrow.maxBatchSize' (default 4MiB),
+                # here use 4KiB for each chunk, which mean each arrow batch
+                # may contain about 1000 chunks.
+                chunks = []
+                chunk_size = 4096
+                index = 0
+                while index < output_size:
+                    chunks.append(output_bytes[index : index + chunk_size])
+                    index += chunk_size
+
+                yield pd.DataFrame(data={"chunk": chunks})
 
         return wrapped_train_fn
 
@@ -568,7 +589,8 @@ class TorchDistributor(Distributor):
             raise RuntimeError("Unknown combination of parameters")
 
         log_streaming_server = LogStreamingServer()
-        self.driver_address = get_driver_host(self.sc)
+        self.driver_address = _get_conf(self.spark, "spark.driver.host", "")
+        assert self.driver_address != ""
         log_streaming_server.start(spark_host_address=self.driver_address)
         time.sleep(1)  # wait for the server to start
         self.log_streaming_server_port = log_streaming_server.port
@@ -578,19 +600,20 @@ class TorchDistributor(Distributor):
         )
         self._check_encryption()
         self.logger.info(
-            f"Started distributed training with {self.num_processes} executor 
proceses"
+            f"Started distributed training with {self.num_processes} executor 
processes"
         )
         try:
-            result = (
-                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-                .barrier()
-                .mapPartitions(spark_task_function)
-                .collect()[0]
+            rows = (
+                self.spark.range(start=0, end=self.num_tasks, step=1, 
numPartitions=self.num_tasks)
+                .mapInPandas(func=spark_task_function, schema="chunk binary", 
barrier=True)
+                .collect()
             )
+            output_bytes = b"".join([row.chunk for row in rows])
+            result = cloudpickle.loads(output_bytes)
         finally:
             log_streaming_server.shutdown()
         self.logger.info(
-            f"Finished distributed training with {self.num_processes} executor 
proceses"
+            f"Finished distributed training with {self.num_processes} executor 
processes"
         )
         return result
 
diff --git a/python/pyspark/ml/torch/log_communication.py 
b/python/pyspark/ml/torch/log_communication.py
index e269aed42d1..ca91121d3e3 100644
--- a/python/pyspark/ml/torch/log_communication.py
+++ b/python/pyspark/ml/torch/log_communication.py
@@ -24,19 +24,13 @@ from struct import pack, unpack
 import sys
 import threading
 import traceback
-from typing import Optional, Generator
+from typing import Generator
 import warnings
-from pyspark.context import SparkContext
 
 # Use b'\x00' as separator instead of b'\n', because the bytes are encoded in 
utf-8
 _SERVER_POLL_INTERVAL = 0.1
 _TRUNCATE_MSG_LEN = 4000
 
-
-def get_driver_host(sc: SparkContext) -> Optional[str]:
-    return sc.getConf().get("spark.driver.host")
-
-
 _log_print_lock = threading.Lock()  # pylint: disable=invalid-name
 
 
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py 
b/python/pyspark/ml/torch/tests/test_distributor.py
index baf68757f67..56569770157 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -117,21 +117,12 @@ def create_training_function(mnist_dir_path: str) -> 
Callable:
                 optimizer.step()
             print(f"epoch {epoch} finished.")
 
-        return "success"
+        return "success" * 4096
 
     return train_fn
 
 
[email protected](not have_torch, "torch is required")
-class TorchDistributorBaselineUnitTests(unittest.TestCase):
-    def setUp(self) -> None:
-        conf = SparkConf()
-        self.sc = SparkContext("local[4]", conf=conf)
-        self.spark = SparkSession(self.sc)
-
-    def tearDown(self) -> None:
-        self.spark.stop()
-
+class TorchDistributorBaselineUnitTestsMixin:
     def setup_env_vars(self, input_map: Dict[str, str]) -> None:
         for key, value in input_map.items():
             os.environ[key] = value
@@ -175,8 +166,8 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
         ]
         for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value 
in inputs:
             with self.subTest():
-                self.spark.sparkContext._conf.set(ssl_conf_key, ssl_conf_value)
-                self.spark.sparkContext._conf.set(pytorch_conf_key, 
pytorch_conf_value)
+                self.spark.conf.set(ssl_conf_key, ssl_conf_value)
+                self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
                 distributor = TorchDistributor(1, True, False)
                 distributor._check_encryption()
 
@@ -186,8 +177,8 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
         for ssl_conf_key, ssl_conf_value, pytorch_conf_key, pytorch_conf_value 
in inputs:
             with self.subTest():
                 with self.assertRaisesRegex(Exception, "encryption"):
-                    self.spark.sparkContext._conf.set(ssl_conf_key, 
ssl_conf_value)
-                    self.spark.sparkContext._conf.set(pytorch_conf_key, 
pytorch_conf_value)
+                    self.spark.conf.set(ssl_conf_key, ssl_conf_value)
+                    self.spark.conf.set(pytorch_conf_key, pytorch_conf_value)
                     distributor = TorchDistributor(1, True, False)
                     distributor._check_encryption()
 
@@ -279,9 +270,18 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
 
 
 @unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorLocalUnitTests(unittest.TestCase):
+class 
TorchDistributorBaselineUnitTests(TorchDistributorBaselineUnitTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
-        class_name = self.__class__.__name__
+        conf = SparkConf()
+        sc = SparkContext("local[4]", conf=conf)
+        self.spark = SparkSession(sc)
+
+    def tearDown(self) -> None:
+        self.spark.stop()
+
+
+class TorchDistributorLocalUnitTestsMixin:
+    def _get_spark_conf(self) -> SparkConf:
         self.gpu_discovery_script_file = 
tempfile.NamedTemporaryFile(delete=False)
         self.gpu_discovery_script_file.write(
             b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": 
[\\"0\\",\\"1\\",\\"2\\"]}'
@@ -294,21 +294,13 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
             self.gpu_discovery_script_file.name,
             stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | 
stat.S_IXOTH,
         )
-        conf = SparkConf().set("spark.test.home", SPARK_HOME)
 
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
         conf = conf.set("spark.driver.resource.gpu.amount", "3")
         conf = conf.set(
             "spark.driver.resource.gpu.discoveryScript", 
self.gpu_discovery_script_file.name
         )
-
-        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, 
conf=conf)
-        self.spark = SparkSession(self.sc)
-        self.mnist_dir_path = tempfile.mkdtemp()
-
-    def tearDown(self) -> None:
-        shutil.rmtree(self.mnist_dir_path)
-        os.unlink(self.gpu_discovery_script_file.name)
-        self.spark.stop()
+        return conf
 
     def setup_env_vars(self, input_map: Dict[str, str]) -> None:
         for key, value in input_map.items():
@@ -336,11 +328,11 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
 
     def test_get_gpus_owned_local(self) -> None:
         addresses = ["0", "1", "2"]
-        self.assertEqual(get_gpus_owned(self.sc), addresses)
+        self.assertEqual(get_gpus_owned(self.spark), addresses)
 
         env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"}
         self.setup_env_vars(env_vars)
-        self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+        self.assertEqual(get_gpus_owned(self.spark), ["3", "4", "5"])
         self.delete_env_vars(env_vars)
 
     def test_local_training_succeeds(self) -> None:
@@ -382,27 +374,40 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
         output = TorchDistributor(num_processes=2, local_mode=True, 
use_gpu=False).run(
             train_fn, 0.001
         )
-        self.assertEqual(output, "success")
+        self.assertEqual(output, "success" * 4096)
 
 
 @unittest.skipIf(not have_torch, "torch is required")
-class TorchDistributorDistributedUnitTests(unittest.TestCase):
+class TorchDistributorLocalUnitTests(TorchDistributorLocalUnitTestsMixin, 
unittest.TestCase):
     def setUp(self) -> None:
         class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(sc)
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+
+class TorchDistributorDistributedUnitTestsMixin:
+    def _get_spark_conf(self) -> SparkConf:
         self.gpu_discovery_script_file = 
tempfile.NamedTemporaryFile(delete=False)
         self.gpu_discovery_script_file.write(
             b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": 
[\\"0\\",\\"1\\",\\"2\\"]}'
         )
         self.gpu_discovery_script_file.close()
         # create temporary directory for Worker resources coordination
-        self.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        os.unlink(self.tempdir.name)
+        tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(tempdir.name)
         os.chmod(
             self.gpu_discovery_script_file.name,
             stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | 
stat.S_IXOTH,
         )
-        conf = SparkConf().set("spark.test.home", SPARK_HOME)
 
+        conf = SparkConf().set("spark.test.home", SPARK_HOME)
         conf = conf.set(
             "spark.worker.resource.gpu.discoveryScript", 
self.gpu_discovery_script_file.name
         )
@@ -410,15 +415,7 @@ class 
TorchDistributorDistributedUnitTests(unittest.TestCase):
         conf = conf.set("spark.task.cpus", "2")
         conf = conf.set("spark.task.resource.gpu.amount", "1")
         conf = conf.set("spark.executor.resource.gpu.amount", "1")
-
-        self.sc = SparkContext("local-cluster[2,2,1024]", class_name, 
conf=conf)
-        self.spark = SparkSession(self.sc)
-        self.mnist_dir_path = tempfile.mkdtemp()
-
-    def tearDown(self) -> None:
-        shutil.rmtree(self.mnist_dir_path)
-        os.unlink(self.gpu_discovery_script_file.name)
-        self.spark.stop()
+        return conf
 
     def test_dist_training_succeeds(self) -> None:
         CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -442,13 +439,11 @@ class 
TorchDistributorDistributedUnitTests(unittest.TestCase):
 
         for spark_conf_value, num_processes, expected_output in inputs:
             with self.subTest():
-                self.spark.sparkContext._conf.set(
-                    "spark.task.resource.gpu.amount", str(spark_conf_value)
-                )
+                self.spark.conf.set("spark.task.resource.gpu.amount", 
str(spark_conf_value))
                 distributor = TorchDistributor(num_processes, False, True)
                 self.assertEqual(distributor._get_num_tasks(), expected_output)
 
-        self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", 
"1")
+        self.spark.conf.set("spark.task.resource.gpu.amount", "1")
 
     def test_distributed_file_with_pytorch(self) -> None:
         test_file_path = "python/test_support/test_pytorch_training_file.py"
@@ -462,11 +457,27 @@ class 
TorchDistributorDistributedUnitTests(unittest.TestCase):
         output = TorchDistributor(num_processes=2, local_mode=False, 
use_gpu=False).run(
             train_fn, 0.001
         )
-        self.assertEqual(output, "success")
+        self.assertEqual(output, "success" * 4096)
 
 
 @unittest.skipIf(not have_torch, "torch is required")
-class TorchWrapperUnitTests(unittest.TestCase):
+class TorchDistributorDistributedUnitTests(
+    TorchDistributorDistributedUnitTestsMixin, unittest.TestCase
+):
+    def setUp(self) -> None:
+        class_name = self.__class__.__name__
+        conf = self._get_spark_conf()
+        sc = SparkContext("local-cluster[2,2,1024]", class_name, conf=conf)
+        self.spark = SparkSession(sc)
+        self.mnist_dir_path = tempfile.mkdtemp()
+
+    def tearDown(self) -> None:
+        shutil.rmtree(self.mnist_dir_path)
+        os.unlink(self.gpu_discovery_script_file.name)
+        self.spark.stop()
+
+
+class TorchWrapperUnitTestsMixin:
     def test_clean_and_terminate(self) -> None:
         def kill_task(task: "subprocess.Popen") -> None:
             time.sleep(1)
@@ -489,6 +500,11 @@ class TorchWrapperUnitTests(unittest.TestCase):
         self.assertEqual(mock_clean_and_terminate.call_count, 0)
 
 
[email protected](not have_torch, "torch is required")
+class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
+    pass
+
+
 if __name__ == "__main__":
     from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to