mathewjacob1002 commented on code in PR #41770:
URL: https://github.com/apache/spark/pull/41770#discussion_r1258664693
##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -543,6 +545,126 @@ def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
pass
+class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
+
+    def _get_env_var(self, var_name: str, default_value: Any) -> Any:
+        value = os.getenv(var_name)
+        if value:
+            return value
+        # os.environ only stores strings, so cast the default before caching it
+        os.environ[var_name] = str(default_value)
+        return default_value
+
+    def _get_env_variables_distributed(self):
+        MASTER_ADDR = self._get_env_var("MASTER_ADDR", "127.0.0.1")
+        MASTER_PORT = self._get_env_var("MASTER_PORT", 2000)
+        RANK = self._get_env_var("RANK", 0)
+        return MASTER_ADDR, MASTER_PORT, RANK
+
+    def test_get_torchrun_args(self) -> None:
+        number_of_processes = 5
+        EXPECTED_TORCHRUN_ARGS_LOCAL = ["--standalone", "--nnodes=1"]
+        EXPECTED_PROCESSES_PER_NODE_LOCAL = number_of_processes
+        get_local_mode_torchrun_args, process_per_node = DeepspeedTorchDistributor._get_torchrun_args(
+            True, number_of_processes
+        )
+        self.assertEqual(get_local_mode_torchrun_args, EXPECTED_TORCHRUN_ARGS_LOCAL)
+        self.assertEqual(process_per_node, EXPECTED_PROCESSES_PER_NODE_LOCAL)
+
+        MASTER_ADDR, MASTER_PORT, RANK = self._get_env_variables_distributed()
+        EXPECTED_TORCHRUN_ARGS_DISTRIBUTED = [
+            f"--nnodes={number_of_processes}",
+            f"--node_rank={RANK}",
+            f"--rdzv_endpoint={MASTER_ADDR}:{MASTER_PORT}",
+            "--rdzv_id=0",
+        ]
+        torchrun_args_distributed, process_per_node = DeepspeedTorchDistributor._get_torchrun_args(
+            False, number_of_processes
+        )
+        self.assertEqual(torchrun_args_distributed, EXPECTED_TORCHRUN_ARGS_DISTRIBUTED)
+        self.assertEqual(process_per_node, 1)
Review Comment:
Done!
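
For context on what the assertion pairs above pin down: below is a minimal sketch of the contract `DeepspeedTorchDistributor._get_torchrun_args(local_mode, num_processes)` is expected to satisfy, inferred purely from the test's expected values. The function name and structure are assumptions for illustration, not the PR's actual implementation.

```python
import os
from typing import List, Tuple


def get_torchrun_args_sketch(local_mode: bool, num_processes: int) -> Tuple[List[str], int]:
    """Hypothetical stand-in mirroring the behavior the test above asserts."""
    if local_mode:
        # Local mode: standalone single-node rendezvous; all processes run on this node.
        return ["--standalone", "--nnodes=1"], num_processes
    # Distributed mode: rendezvous through MASTER_ADDR/MASTER_PORT, one process per node.
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]
    node_rank = os.environ["RANK"]
    torchrun_args = [
        f"--nnodes={num_processes}",
        f"--node_rank={node_rank}",
        f"--rdzv_endpoint={master_addr}:{master_port}",
        "--rdzv_id=0",
    ]
    return torchrun_args, 1
```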
##########
python/pyspark/ml/torch/tests/test_distributor.py:
##########
@@ -543,6 +545,126 @@ def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
class TorchWrapperUnitTests(TorchWrapperUnitTestsMixin, unittest.TestCase):
pass
+class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
+
+    def test_create_torchrun_command(self):
Review Comment:
Done!
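
Since the second hunk ends at `test_create_torchrun_command`, here is a hypothetical sketch, building on `get_torchrun_args_sketch` above, of how such torchrun args might be folded into a full launch command. The helper's name and shape are assumed for illustration; only `torch.distributed.run` and `--nproc_per_node` are standard torchrun surface.

```python
import sys
from typing import Any, List


def create_torchrun_command_sketch(
    local_mode: bool, num_processes: int, train_file: str, *script_args: Any
) -> List[str]:
    """Hypothetical helper: combine torchrun args with the training script and its args."""
    torchrun_args, procs_per_node = get_torchrun_args_sketch(local_mode, num_processes)
    return [
        sys.executable,
        "-m",
        "torch.distributed.run",  # the module behind the `torchrun` console script
        *torchrun_args,
        f"--nproc_per_node={procs_per_node}",
        train_file,
        *[str(arg) for arg in script_args],
    ]
```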
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]