maddiedawson commented on code in PR #41770:
URL: https://github.com/apache/spark/pull/41770#discussion_r1258882891
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -393,30 +394,50 @@ def __init__(
self._validate_input_params()
self.input_params = self._create_input_params()
+
+ @staticmethod
+ def _get_torchrun_args(local_mode: bool, num_processes: int) ->
Tuple[List[Any], int]:
+ """
+ Given the mode and the number of processes, create the arguments to be
given to for torch
+
+ Parameters
+ ---------
+ local_mode: bool
+ Whether or not we are running training locally or in a distributed
fashion
+
+ num_processes: int
+ The number of processes that we are going to use
+
+ Returns
+ ------
+ Tuple[List[Any], int]
+ A tuple containing a list of arguments to pass as pytorch args ,
as well as the number of processes per node
+ """
+ if local_mode:
+ torchrun_args = ["--standalone","--nnodes=1"]
+ processes_per_node = num_processes
+ return torchrun_args, processes_per_node
+
+ master_addr = os.environ["MASTER_ADDR"]
+ master_port = os.environ["MASTER_PORT"]
+ node_rank = os.environ["RANK"]
+ torchrun_args = [
+ f"--nnodes={num_processes}",
+ f"--node_rank={node_rank}",
+ f"--rdzv_endpoint={master_addr}:{master_port}",
+ "--rdzv_id=0",
+ ]
+ processes_per_node = 1
+ return torchrun_args, processes_per_node
+
@staticmethod
def _create_torchrun_command(
input_params: Dict[str, Any], path_to_train_file: str, *args: Any
) -> List[str]:
local_mode = input_params["local_mode"]
num_processes = input_params["num_processes"]
-
- if local_mode:
- torchrun_args = ["--standalone", "--nnodes=1"]
- processes_per_node = num_processes
- else:
- master_addr, master_port = (
- os.environ["MASTER_ADDR"],
- os.environ["MASTER_PORT"],
- )
- node_rank = os.environ["RANK"]
- torchrun_args = [
- f"--nnodes={num_processes}",
- f"--node_rank={node_rank}",
- f"--rdzv_endpoint={master_addr}:{master_port}",
- "--rdzv_id=0",
- ] # TODO: setup random ID that is gleaned from env variables
Review Comment:
Let's keep the TODO comment (about setting up a random rendezvous ID gleaned from env variables) in the new `_get_torchrun_args` function.
##########
python/pyspark/ml/torch/deepspeed/tests/test_deepspeed_distributor.py:
##########
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+from typing import List, Any, Callable, Dict
+import unittest
+
+from pyspark.ml.torch.deepspeed.deepspeed_distributor import
DeepspeedTorchDistributor
+
+class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
+
+ def _get_env_var(self, var_name: str, default_value: Any) -> Any:
+ value = os.getenv(var_name)
+ if value:
+ return value
+ os.environ[var_name] = str(default_value)
+ value = default_value
+ return value
+
+ def _get_env_variables_distributed(self):
+ MASTER_ADDR = self._get_env_var("MASTER_ADDR", "127.0.0.1")
+ MASTER_PORT = self._get_env_var("MASTER_PORT", 2000)
+ RANK = self._get_env_var("RANK", 0)
+ return MASTER_ADDR, MASTER_PORT, RANK
+
+ def test_get_torchrun_args_local(self):
+ number_of_processes = 5
+ EXPECTED_TORCHRUN_ARGS_LOCAL= [
+ "--standalone", "--nnodes=1"
+ ]
+ EXPECTED_PROCESSES_PER_NODE_LOCAL = number_of_processes
+ get_local_mode_torchrun_args, process_per_node=
DeepspeedTorchDistributor._get_torchrun_args(True, number_of_processes)
+
self.assertEqual(get_local_mode_torchrun_args,EXPECTED_TORCHRUN_ARGS_LOCAL)
+ self.assertEqual(EXPECTED_PROCESSES_PER_NODE_LOCAL,process_per_node)
+
+ def test_get_torchrun_args_distributed(self):
+ number_of_processes = 5
+ MASTER_ADDR, MASTER_PORT, RANK = self._get_env_variables_distributed()
+ EXPECTED_TORCHRUN_ARGS_DISTRIBUTED = [
+ f"--nnodes={number_of_processes}",
+ f"--node_rank={RANK}",
+ f"--rdzv_endpoint={MASTER_ADDR}:{MASTER_PORT}",
+ "--rdzv_id=0"
+ ]
+ torchrun_args_distributed, process_per_node =
DeepspeedTorchDistributor._get_torchrun_args(False, number_of_processes)
+
self.assertEqual(torchrun_args_distributed,EXPECTED_TORCHRUN_ARGS_DISTRIBUTED)
+ self.assertEqual(process_per_node,1)
+
+ def test_create_torchrun_command_local(self):
+ DEEPSPEED_CONF = "path/to/deepspeed"
+ TRAIN_FILE_PATH = "path/to/exec"
+ NUM_PROCS = 10
+ input_params = {}
+ input_params["local_mode"] = True
+ input_params['num_processes'] = NUM_PROCS
+ input_params["deepspeed_config"] = DEEPSPEED_CONF
+
+ # get the arguments for no argument, local run
+ torchrun_local_args_expected = ["--standalone", "--nnodes=1"]
+ with self.subTest(msg="Testing local training with no extra args"):
+ LOCAL_CMD_NO_ARGS_EXPECTED= [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *torchrun_local_args_expected,
+ f"--nproc_per_node={NUM_PROCS}",
+ TRAIN_FILE_PATH,
+ "-deepspeed",
+ "--deepspeed_config",
+ DEEPSPEED_CONF
+ ]
+ local_cmd =
DeepspeedTorchDistributor._create_torchrun_command(input_params,
TRAIN_FILE_PATH)
+ self.assertEqual(local_cmd,LOCAL_CMD_NO_ARGS_EXPECTED)
+ with self.subTest(msg="Testing local training with extra args for the
training script"):
+ local_mode_version_args = ["--arg1", "--arg2"]
+ LOCAL_CMD_ARGS_EXPECTED= [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *torchrun_local_args_expected,
+ f"--nproc_per_node={NUM_PROCS}",
+ TRAIN_FILE_PATH,
+ *local_mode_version_args,
+ "-deepspeed",
+ "--deepspeed_config",
+ DEEPSPEED_CONF
+ ]
+
+ local_cmd_with_args =
DeepspeedTorchDistributor._create_torchrun_command(input_params,
TRAIN_FILE_PATH, *local_mode_version_args)
+ self.assertEqual(local_cmd_with_args,LOCAL_CMD_ARGS_EXPECTED)
+
+ def test_create_torchrun_command_distributed(self):
+ DEEPSPEED_CONF = "path/to/deepspeed"
+ TRAIN_FILE_PATH = "path/to/exec"
+ NUM_PROCS = 10
+ input_params = {}
+ input_params["local_mode"] = True
+ input_params['num_processes'] = NUM_PROCS
+ input_params["deepspeed_config"] = DEEPSPEED_CONF
+ # distributed training environment
+ distributed_master_address, distributed_master_port, distributed_rank
= self._get_env_variables_distributed()
+ distributed_torchrun_args = [
+ f"--nnodes={NUM_PROCS}",
+ f"--node_rank={distributed_rank}",
+
f"--rdzv_endpoint={distributed_master_address}:{distributed_master_port}",
+ "--rdzv_id=0",
+ ]
+ DISTRIBUTED_CMD_NO_ARGS_EXPECTED = [sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *distributed_torchrun_args,
+ "--nproc_per_node=1",
+ TRAIN_FILE_PATH,
+ "-deepspeed",
+ "--deepspeed_config",
+ DEEPSPEED_CONF
+ ]
+ # test distributed training without arguments
+ input_params["local_mode"] = False
+ distributed_command =
DeepspeedTorchDistributor._create_torchrun_command(input_params,
TRAIN_FILE_PATH)
+ self.assertEqual(DISTRIBUTED_CMD_NO_ARGS_EXPECTED,distributed_command)
+ # test distributed training with random arguments
+ distributed_extra_args = ["-args1", "--args2"]
+ DISTRIBUTED_CMD_ARGS_EXPECTED = [sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *distributed_torchrun_args,
+ "--nproc_per_node=1",
+ TRAIN_FILE_PATH,
+ *distributed_extra_args,
+ "-deepspeed",
+ "--deepspeed_config",
+ DEEPSPEED_CONF
+ ]
+ print("The distributed training command: ",
DISTRIBUTED_CMD_ARGS_EXPECTED)
+ distributed_command_with_args =
DeepspeedTorchDistributor._create_torchrun_command(input_params,
TRAIN_FILE_PATH,*distributed_extra_args)
+
self.assertEqual(DISTRIBUTED_CMD_ARGS_EXPECTED,distributed_command_with_args)
Review Comment:
It's inconsistent that the local-mode test above uses sub-tests for the args vs.
no-args cases, but this distributed-mode test does not.
##########
python/pyspark/ml/torch/deepspeed/deepspeed_distributor.py:
##########
@@ -0,0 +1,144 @@
+
+import json
+import os
+import sys
+import tempfile
+from typing import (
+ Union,
+ Callable,
+ List,
+ Dict,
+ Optional,
+ Any,
+ Tuple,
+)
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+class DeepspeedTorchDistributor(TorchDistributor):
+
+ def __init__(self, num_gpus: int = 1, nnodes: int = 1, local_mode: bool =
True, use_gpu: bool = True, deepspeed_config = None):
+ """
+ This class is used to run deepspeed training workloads with spark
clusters. The user has the option to
+ specify the number of gpus per node and the number of nodes (the
same as if running from terminal),
+ as well as specify a deepspeed configuration file.
+
+ Parameters
+ ----------
+ num_gpus: int
+ The number of GPUs to use per node (analagous to num_gpus in
deepspeed command).
+
+ nnodes: int
+ The number of nodes that should be used for the run.
+
+ local_mode: bool
+ Whether or not to run the training in a distributed fashion or
just locally.
+
+ use_gpu: bool
+ Boolean flag to determine whether to utilize gpus.
+
+ deepspeed_config: Union[Dict[str,Any], str] or None:
+ The configuration file to be used for launching the deepspeed
application.
+ If it is a dictionary mapping parameters to values, then we
will create the file.
+ If None, deepspeed will fall back to default parameters.
+ """
+ num_processes = num_gpus * nnodes
+ super().__init__(num_processes, local_mode, use_gpu)
+ self.deepspeed_config = deepspeed_config
+ self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+ self._validate_input_params()
+ self.input_params = self._create_input_params()
Review Comment:
Then we can remove the duplicate _validate_input_params and
_create_input_params calls
##########
python/pyspark/ml/torch/deepspeed/deepspeed_distributor.py:
##########
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import os
+import sys
+import tempfile
+from typing import (
+ Union,
+ Callable,
+ List,
+ Dict,
+ Optional,
+ Any,
+ Tuple,
+)
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+class DeepspeedTorchDistributor(TorchDistributor):
+
+ def __init__(self, num_gpus: int = 1,
+ nnodes: int = 1,
+ local_mode: bool = True,
+ use_gpu: bool = True,
+ deepspeed_config = None):
+ """
+ This class is used to run deepspeed training workloads with spark
clusters. The user has the option to
+ specify the number of gpus per node and the number of nodes (the
same as if running from terminal),
+ as well as specify a deepspeed configuration file.
+
+ Parameters
+ ----------
+ num_gpus: int
+ The number of GPUs to use per node (analagous to num_gpus in
deepspeed command).
+
+ nnodes: int
+ The number of nodes that should be used for the run.
+
+ local_mode: bool
+ Whether or not to run the training in a distributed fashion or
just locally.
+
+ use_gpu: bool
+ Boolean flag to determine whether to utilize gpus.
+
+ deepspeed_config: Union[Dict[str,Any], str] or None:
+ The configuration file to be used for launching the deepspeed
application.
+ If it is a dictionary mapping parameters to values, then we
will create the file.
+ If None, deepspeed will fall back to default parameters.
+ """
+ num_processes = num_gpus * nnodes
+ super().__init__(num_processes, local_mode, use_gpu)
+ self.deepspeed_config = deepspeed_config
+ self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+ self._validate_input_params()
+ self.input_params = self._create_input_params()
+ self.cleanup_deepspeed_conf = False
+
+ @staticmethod
+ def _get_deepspeed_config_path(deepspeed_config):
+ if isinstance(deepspeed_config, dict):
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False,
suffix='.json') as fil:
+ json.dump(deepspeed_config, fil)
+ return fil.name
+ deepspeed_config_path = deepspeed_config
+ # Empty value means the deepspeed will fall back to default settings.
+ if deepspeed_config == None:
+ return ""
+ return deepspeed_config_path
+
+ @staticmethod
+ def _create_torchrun_command(
+ input_params: Dict[str, Any], train_path: str, *args: Any) ->
List[str]:
+ local_mode = input_params["local_mode"]
+ num_processes = input_params["num_processes"]
+ deepspeed_config = input_params["deepspeed_config"]
+ deepspeed_config_path =
DeepspeedTorchDistributor._get_deepspeed_config_path(deepspeed_config)
+ torchrun_args, processes_per_node =
TorchDistributor._get_torchrun_args(local_mode, num_processes)
+ args_string = list(map(str, args))
+ command_to_run = [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *torchrun_args,
+ f"--nproc_per_node={processes_per_node}",
+ train_path,
+ *args_string,
+ "-deepspeed",
+ "--deepspeed_config",
+ deepspeed_config_path
+ ]
+ return command_to_run
+
+ @staticmethod
+ def _run_training_on_pytorch_file(input_params: Dict[str, Any],
train_path: str, *args: Any, **kwargs : Any) -> None :
+ if kwargs:
+ raise ValueError("DeepspeedTorchDistributor with pytorch file
doesn't support key-word type arguments")
+
+ log_streaming_client = input_params.get("log_streaming_client", None)
+ training_command =
DeepspeedTorchDistributor._create_torchrun_command(input_params, train_path,
*args)
+ DeepspeedTorchDistributor._execute_command(training_command,
log_streaming_client=log_streaming_client)
+
+ def run(self, train_object: Union[Callable, str], *args : Any, **kwargs:
Any) -> Optional[Any]:
+ # If the "train_object" is a string, then we assume it's a filepath.
Otherwise, we assume it's a function.
+ if isinstance(train_object, str):
+ framework_wrapper_fn =
DeepspeedTorchDistributor._run_training_on_pytorch_file
+ else:
+ raise RuntimeError("Work in progress; not supported atm")
+
+ if self.local_mode:
+ output = self._run_local_training(framework_wrapper_fn,
train_object, *args, **kwargs)
Review Comment:
You don't need to assign `output` if you are immediately returning it. The same
applies below.
##########
python/pyspark/ml/torch/deepspeed/tests/test_deepspeed_distributor.py:
##########
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+from typing import List, Any, Callable, Dict
+import unittest
+
+from pyspark.ml.torch.deepspeed.deepspeed_distributor import
DeepspeedTorchDistributor
+
+class DeepspeedTorchDistributorUnitTests(unittest.TestCase):
+
+ def _get_env_var(self, var_name: str, default_value: Any) -> Any:
+ value = os.getenv(var_name)
+ if value:
+ return value
+ os.environ[var_name] = str(default_value)
+ value = default_value
Review Comment:
You don't need to assign value if you are immediately returning it
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -40,6 +40,7 @@
Iterator,
)
+
Review Comment:
Undo this change (the added blank line).
##########
python/pyspark/ml/torch/deepspeed/deepspeed_distributor.py:
##########
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+import os
+import sys
+import tempfile
+from typing import (
+ Union,
+ Callable,
+ List,
+ Dict,
+ Optional,
+ Any,
+ Tuple,
+)
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+class DeepspeedTorchDistributor(TorchDistributor):
+
+ def __init__(self, num_gpus: int = 1,
+ nnodes: int = 1,
+ local_mode: bool = True,
+ use_gpu: bool = True,
+ deepspeed_config = None):
+ """
+ This class is used to run deepspeed training workloads with spark
clusters. The user has the option to
+ specify the number of gpus per node and the number of nodes (the
same as if running from terminal),
+ as well as specify a deepspeed configuration file.
+
+ Parameters
+ ----------
+ num_gpus: int
+ The number of GPUs to use per node (analagous to num_gpus in
deepspeed command).
+
+ nnodes: int
+ The number of nodes that should be used for the run.
+
+ local_mode: bool
+ Whether or not to run the training in a distributed fashion or
just locally.
+
+ use_gpu: bool
+ Boolean flag to determine whether to utilize gpus.
+
+ deepspeed_config: Union[Dict[str,Any], str] or None:
+ The configuration file to be used for launching the deepspeed
application.
+ If it is a dictionary mapping parameters to values, then we
will create the file.
+ If None, deepspeed will fall back to default parameters.
+ """
+ num_processes = num_gpus * nnodes
+ super().__init__(num_processes, local_mode, use_gpu)
+ self.deepspeed_config = deepspeed_config
+ self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+ self._validate_input_params()
+ self.input_params = self._create_input_params()
+ self.cleanup_deepspeed_conf = False
+
+ @staticmethod
+ def _get_deepspeed_config_path(deepspeed_config):
+ if isinstance(deepspeed_config, dict):
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False,
suffix='.json') as fil:
+ json.dump(deepspeed_config, fil)
+ return fil.name
+ deepspeed_config_path = deepspeed_config
+ # Empty value means the deepspeed will fall back to default settings.
+ if deepspeed_config == None:
+ return ""
+ return deepspeed_config_path
+
+ @staticmethod
+ def _create_torchrun_command(
+ input_params: Dict[str, Any], train_path: str, *args: Any) ->
List[str]:
+ local_mode = input_params["local_mode"]
+ num_processes = input_params["num_processes"]
+ deepspeed_config = input_params["deepspeed_config"]
+ deepspeed_config_path =
DeepspeedTorchDistributor._get_deepspeed_config_path(deepspeed_config)
+ torchrun_args, processes_per_node =
TorchDistributor._get_torchrun_args(local_mode, num_processes)
+ args_string = list(map(str, args))
+ command_to_run = [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *torchrun_args,
+ f"--nproc_per_node={processes_per_node}",
+ train_path,
+ *args_string,
+ "-deepspeed",
+ "--deepspeed_config",
+ deepspeed_config_path
+ ]
+ return command_to_run
+
+ @staticmethod
+ def _run_training_on_pytorch_file(input_params: Dict[str, Any],
train_path: str, *args: Any, **kwargs : Any) -> None :
+ if kwargs:
+ raise ValueError("DeepspeedTorchDistributor with pytorch file
doesn't support key-word type arguments")
+
+ log_streaming_client = input_params.get("log_streaming_client", None)
+ training_command =
DeepspeedTorchDistributor._create_torchrun_command(input_params, train_path,
*args)
+ DeepspeedTorchDistributor._execute_command(training_command,
log_streaming_client=log_streaming_client)
+
+ def run(self, train_object: Union[Callable, str], *args : Any, **kwargs:
Any) -> Optional[Any]:
+ # If the "train_object" is a string, then we assume it's a filepath.
Otherwise, we assume it's a function.
+ if isinstance(train_object, str):
+ framework_wrapper_fn =
DeepspeedTorchDistributor._run_training_on_pytorch_file
+ else:
+ raise RuntimeError("Work in progress; not supported atm")
Review Comment:
This error is not very helpful; it doesn't say what is not supported. Let's
change it to say "Running DeepspeedTorchDistributor on Python functions is not
currently supported"
##########
python/pyspark/ml/torch/deepspeed/deepspeed_distributor.py:
##########
@@ -0,0 +1,144 @@
+
+import json
+import os
+import sys
+import tempfile
+from typing import (
+ Union,
+ Callable,
+ List,
+ Dict,
+ Optional,
+ Any,
+ Tuple,
+)
+
+from pyspark.ml.torch.distributor import TorchDistributor
+
+class DeepspeedTorchDistributor(TorchDistributor):
+
+ def __init__(self, num_gpus: int = 1, nnodes: int = 1, local_mode: bool =
True, use_gpu: bool = True, deepspeed_config = None):
+ """
+ This class is used to run deepspeed training workloads with spark
clusters. The user has the option to
+ specify the number of gpus per node and the number of nodes (the
same as if running from terminal),
+ as well as specify a deepspeed configuration file.
+
+ Parameters
+ ----------
+ num_gpus: int
+ The number of GPUs to use per node (analagous to num_gpus in
deepspeed command).
+
+ nnodes: int
+ The number of nodes that should be used for the run.
+
+ local_mode: bool
+ Whether or not to run the training in a distributed fashion or
just locally.
+
+ use_gpu: bool
+ Boolean flag to determine whether to utilize gpus.
+
+ deepspeed_config: Union[Dict[str,Any], str] or None:
+ The configuration file to be used for launching the deepspeed
application.
+ If it is a dictionary mapping parameters to values, then we
will create the file.
+ If None, deepspeed will fall back to default parameters.
+ """
+ num_processes = num_gpus * nnodes
+ super().__init__(num_processes, local_mode, use_gpu)
+ self.deepspeed_config = deepspeed_config
+ self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+ self._validate_input_params()
+ self.input_params = self._create_input_params()
Review Comment:
Oh I see. Yes, let's pass it into the constructor as an optional argument.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]