WeichenXu123 commented on code in PR #39369:
URL: https://github.com/apache/spark/pull/39369#discussion_r1082443370


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -495,32 +546,119 @@ def set_gpus(context: "BarrierTaskContext") -> None:
 
     def _run_distributed_training(
         self,
-        framework_wrapper_fn: Optional[Callable],
+        framework_wrapper_fn: Callable,
         train_object: Union[Callable, str],
         *args: Any,
     ) -> Optional[Any]:
         if not framework_wrapper_fn:
             raise RuntimeError("Unknown combination of parameters")
+
+        log_streaming_server = LogStreamingServer()
+        self.driver_address = get_driver_host(self.sc)
+        log_streaming_server.start()
+        time.sleep(1)  # wait for the server to start
+        self.log_streaming_server_port = log_streaming_server.port
+
         spark_task_function = self._get_spark_task_function(
             framework_wrapper_fn, train_object, *args
         )
         self._check_encryption()
-        result = (
-            self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-            .barrier()
-            .mapPartitions(spark_task_function)
-            .collect()[0]
+        self.logger.info(
+            f"Started distributed training with {self.num_processes} executor processes"
+        )
+        try:
+            result = (
+                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
+                .barrier()
+                .mapPartitions(spark_task_function)
+                .collect()[0]
+            )
+        finally:
+            log_streaming_server.shutdown()
+        self.logger.info(
+            f"Finished distributed training with {self.num_processes} executor 
proceses"
         )
         return result
 
     @staticmethod
     def _run_training_on_pytorch_file(
         input_params: Dict[str, Any], train_path: str, *args: Any
     ) -> None:
+        log_streaming_client = input_params.get("log_streaming_client", None)
         training_command = TorchDistributor._create_torchrun_command(
             input_params, train_path, *args
         )
-        TorchDistributor._execute_command(training_command)
+        TorchDistributor._execute_command(
+            training_command, log_streaming_client=log_streaming_client
+        )
+
+    @contextmanager
+    @staticmethod
+    def _setup_files(train_fn: Callable, *args: Any) -> Tuple[str, str]:
+        save_dir = TorchDistributor._create_save_dir()
+        pickle_file_path = TorchDistributor._save_pickled_function(save_dir, train_fn, *args)
+        output_file_path = os.path.join(save_dir, TorchDistributor.PICKLED_OUTPUT_FILE)
+        train_file_path = TorchDistributor._create_torchrun_train_file(
+            save_dir, pickle_file_path, output_file_path
+        )
+        try:
+            yield (train_file_path, output_file_path)
+        finally:
+            TorchDistributor._cleanup_files(save_dir)
+
+    @staticmethod
+    def _run_training_on_pytorch_function(
+        input_params: dict[str, Any], train_fn: Callable, *args: Any  # TODO: change dict to Dict
+    ) -> Any:
+        with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
+            args = []  # type: ignore
+            TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
+            output = TorchDistributor._get_pickled_output(output_file_path)
+        return output
+
+    @staticmethod
+    def _create_save_dir() -> str:
+        # TODO: need to do this in a safe way to avoid issues during concurrent runs
+        return tempfile.mkdtemp()
+
+    @staticmethod
+    def _cleanup_files(save_dir: str) -> None:
+        shutil.rmtree(save_dir)

Review Comment:
   Please add the argument `ignore_errors=True`.
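
   For reference, a minimal sketch of what `_cleanup_files` could look like with that argument (a sketch only, assuming the existing `shutil` import in this module; `ignore_errors=True` makes `shutil.rmtree` ignore deletion failures instead of raising):

   ```python
   @staticmethod
   def _cleanup_files(save_dir: str) -> None:
       # ignore_errors=True keeps cleanup from raising if the temporary
       # directory is already gone or cannot be fully removed.
       shutil.rmtree(save_dir, ignore_errors=True)
   ```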



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

