harupy commented on code in PR #39188:
URL: https://github.com/apache/spark/pull/39188#discussion_r1068903058
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -261,6 +312,130 @@ def __init__(
super().__init__(num_processes, local_mode, use_gpu)
self.ssl_conf = "pytorch.spark.distributor.ignoreSsl" # type: ignore
self._validate_input_params()
+ self.input_params = self._create_input_params()
+
+ @staticmethod
+ def _create_torchrun_command(
+ input_params: dict[str, Any], path_to_train_file: str, *args: Any
+ ) -> list[str]:
+ local_mode = input_params["local_mode"]
+ num_processes = input_params["num_processes"]
+
+ if local_mode:
+ torchrun_args = ["--standalone", "--nnodes=1"]
+ processes_per_node = num_processes
+ else:
+ pass
+ # TODO(SPARK-41592): Handle distributed training
+
+ args_string = list(map(str, args)) # converting all args to strings
+
+ return (
+ [sys.executable, "-m",
"pyspark.ml.torch.distributor.torch_run_process_wrapper"]
+ + torchrun_args
+ + [f"--nproc_per_node={processes_per_node}"]
+ + [path_to_train_file, *args_string]
+ )
+
+ @staticmethod
+ def _execute_command(
+ cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
+ ) -> None:
+ _TAIL_LINES_TO_KEEP = 100
+
+ def sigterm_on_parent_death() -> None:
+ """
+ Uses prctl to automatically send SIGTERM to the command process
when its parent is dead.
+ This handles the case when the parent is a PySpark worker process.
+ If a user cancels the PySpark job, the worker process gets killed,
regardless of
+ PySpark daemon and worker reuse settings.
+ """
+ if _prctl:
+ try:
+ libc = ctypes.CDLL("libc.so.6")
+ # Set the parent process death signal of the command
process to SIGTERM.
+ libc.prctl(1, signal.SIGTERM)
+ except OSError:
+ pass
+
+ task = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ stdin=subprocess.PIPE,
+ env=os.environ,
+ preexec_fn=sigterm_on_parent_death,
+ )
+ task.stdin.close() # type: ignore
+ tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
+ try:
+ for line in task.stdout: # type: ignore
+ decoded = line.decode()
+ tail.append(decoded)
+ if redirect_to_stdout:
+ sys.stdout.write(decoded)
+ task.wait()
+ finally:
+ if task.poll() is None:
+ try:
+ task.terminate() # SIGTERM
+ time.sleep(0.5)
+ if task.poll() is None:
+ task.kill() # SIGKILL
+ except OSError:
+ pass
+ if task.returncode != os.EX_OK:
+ if len(tail) == _TAIL_LINES_TO_KEEP:
+ last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task
output are"
+ else:
+ last_n_msg = "task output is"
+ task_output = "".join(tail)
+ raise RuntimeError(
+ f"Command {cmd} failed with return code {task.returncode}."
+ f"The {last_n_msg} included below: {task_output}"
+ )
+
+ def _run_local_training(
+ self,
+ framework_wrapper_fn: Optional[Callable],
+ train_object: Union[Callable, str],
+ *args: Any,
+ ) -> Optional[Any]:
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+ cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
+ old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
+ try:
+ if self.use_gpu:
+ gpus_owned = get_gpus_owned(self.sc)
+
+ if self.num_processes > len(gpus_owned):
+ raise ValueError(
+ f"""{self.num_processes} processes were requested
+ for local training with GPU training but only
+ {len(gpus_owned)} GPUs were available."""
+ )
Review Comment:
This block results in a multi-line message (the triple-quoted f-string keeps its line breaks and leading indentation), like:
```
2 processes ...
for local ...
3 GPUs were ...
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]