mathewjacob1002 commented on code in PR #41770:
URL: https://github.com/apache/spark/pull/41770#discussion_r1256413676
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -1003,3 +1006,112 @@ def _get_spark_partition_data_loader(
# if num_workers is zero, we cannot set `prefetch_factor` otherwise
# torch will raise error.
return DataLoader(dataset, batch_size, num_workers=num_workers)
+
+
+class DeepspeedTorchDistributor(TorchDistributor):
+
+ def __init__(self, num_gpus: int = 1, nnodes: int = 1, local_mode: bool =
True, use_gpu: bool = True, deepspeed_config = None):
+ """
+ @param: num_gpus: the number of gpus per node (the same num_gpus
argument in deepspeed command)
+ @param: nnodes: the number of nodes that you want to run with
(analogous to deepspeed command)
+ @param: local_mode: boolean value representing whether you want
distributed training or to run the training locally
+ @param: use_gpu: represents whether or not to use GPUs
+ @param: deepspeed_config: can be a dictionary representing
arguments for deepspeed config, or can be a string representing the path
+ to a config file. If nothing is specified, deepspeed will
use its default optimizers and settings
+ """
+ num_processes = num_gpus * nnodes
+ super().__init__(num_processes, local_mode, use_gpu)
+ self.deepspeed_config = deepspeed_config
+ self.ssl_conf = "deepspeed.spark.distributor.ignoreSsl"
+ self._validate_input_params()
+ self.input_params = self._create_input_params()
+ self.cleanup_deepspeed_conf = False
+
+ @staticmethod
+ def _get_deepspeed_config_path(deepspeed_config):
+ if isinstance(deepspeed_config, dict):
+ with tempfile.NamedTemporaryFile(mode='w+', delete=False,
suffix='.json') as fil:
+ json.dump(deepspeed_config, fil)
+ return fil.name
+ deepspeed_config_path = deepspeed_config
+ if deepspeed_config == None:
+ deepspeed_config_path = "" # empty value means the deepspeed will
fall back to default settings
+
+ return deepspeed_config_path
+
+
+ @staticmethod
+ def _get_torchrun_args(local_mode, num_processes):
+ # given the number of processes and the mode, create the torchrun
arguments to use when creating deepspeed command
+ if local_mode:
+ torchrun_args = ["--standalone", "--nnodes=1"]
+ processes_per_node = num_processes
+ return torchrun_args, processes_per_node
+
+ master_addr, master_port = (
+ os.environ["MASTER_ADDR"],
+ os.environ["MASTER_PORT"],
+ )
+ node_rank = os.environ["RANK"]
+ torchrun_args = [
+ f"--nnodes={num_processes}",
+ f"--node_rank={node_rank}",
+ f"--rdzv_endpoint={master_addr}:{master_port}",
+ "--rdzv_id=0",
+ ]
+ processes_per_node = 1
+ return torchrun_args, processes_per_node
+
+ @staticmethod
+ def _create_torchrun_command(
+ input_params: Dict[str, Any], train_path: str, *args: Any) ->
List[str]:
+ local_mode = input_params["local_mode"]
+ num_processes = input_params["num_processes"]
+ deepspeed_config = input_params["deepspeed_config"]
+
+ deepspeed_config_path =
DeepspeedTorchDistributor._get_deepspeed_config_path(deepspeed_config)
+
+
+ torchrun_args, processes_per_node =
DeepspeedTorchDistributor._get_torchrun_args(local_mode, num_processes)
+
+ args_string = list(map(str, args))
+
+ command_to_run = [
+ sys.executable,
+ "-m",
+ "torch.distributed.run",
+ *torchrun_args,
+ f"--nproc_per_node={processes_per_node}",
+ train_path,
+ *args_string,
+ "-deepspeed",
+ "--deepspeed_config",
+ deepspeed_config_path
+ ]
+ return command_to_run
+
+
+ @staticmethod
+ def _run_training_on_pytorch_file(input_params: Dict[str, Any],
train_path: str, *args: Any, **kwargs : Any) -> None :
+ if kwargs:
+ raise ValueError("DeepspeedTorchDistributor with pytorch file
doesn't support key-word type arguments")
+
+ log_streaming_client = input_params.get("log_streaming_client", None)
+ training_command =
DeepspeedTorchDistributor._create_torchrun_command(input_params, train_path,
*args)
+ DeepspeedTorchDistributor._execute_command(training_command,
log_streaming_client=log_streaming_client)
+
+ def run(self, train_object: Union[Callable, str], *args : Any, **kwargs:
Any) -> Optional[Any]:
+ # if the "train_object" is a string, then we assume it's a filepath.
Otherwise, we assume it's a function
+ if isinstance(train_object, str):
+ framework_wrapper_fn =
DeepspeedTorchDistributor._run_training_on_pytorch_file
+ else:
+ raise RuntimeError("Work in progress; not supported atm")
+ framework_wrapper_fn =
TorchDistributor._run_training_on_pytorch_file
+ if self.local_mode:
+ output = self._run_local_training(framework_wrapper_fn,
train_object, *args, **kwargs)
+ else:
+ output = self._run_distributed_training(framework_wrapper_fn,
train_object, None, *args, **kwargs)
Review Comment:
Done! Made it a named arg.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]