WeichenXu123 commented on code in PR #39369: URL: https://github.com/apache/spark/pull/39369#discussion_r1082443370
########## python/pyspark/ml/torch/distributor.py: ##########
@@ -495,32 +546,119 @@ def set_gpus(context: "BarrierTaskContext") -> None:

     def _run_distributed_training(
         self,
-        framework_wrapper_fn: Optional[Callable],
+        framework_wrapper_fn: Callable,
         train_object: Union[Callable, str],
         *args: Any,
     ) -> Optional[Any]:
         if not framework_wrapper_fn:
             raise RuntimeError("Unknown combination of parameters")
+
+        log_streaming_server = LogStreamingServer()
+        self.driver_address = get_driver_host(self.sc)
+        log_streaming_server.start()
+        time.sleep(1)  # wait for the server to start
+        self.log_streaming_server_port = log_streaming_server.port
+
         spark_task_function = self._get_spark_task_function(
             framework_wrapper_fn, train_object, *args
         )
         self._check_encryption()
-        result = (
-            self.sc.parallelize(range(self.num_tasks), self.num_tasks)
-            .barrier()
-            .mapPartitions(spark_task_function)
-            .collect()[0]
+        self.logger.info(
+            f"Started distributed training with {self.num_processes} executor proceses"
+        )
+        try:
+            result = (
+                self.sc.parallelize(range(self.num_tasks), self.num_tasks)
+                .barrier()
+                .mapPartitions(spark_task_function)
+                .collect()[0]
+            )
+        finally:
+            log_streaming_server.shutdown()
+        self.logger.info(
+            f"Finished distributed training with {self.num_processes} executor proceses"
         )
         return result

     @staticmethod
     def _run_training_on_pytorch_file(
         input_params: Dict[str, Any], train_path: str, *args: Any
     ) -> None:
+        log_streaming_client = input_params.get("log_streaming_client", None)
         training_command = TorchDistributor._create_torchrun_command(
             input_params, train_path, *args
         )
-        TorchDistributor._execute_command(training_command)
+        TorchDistributor._execute_command(
+            training_command, log_streaming_client=log_streaming_client
+        )
+
+    @contextmanager
+    @staticmethod
+    def _setup_files(train_fn: Callable, *args: Any) -> Tuple[str, str]:
+        save_dir = TorchDistributor._create_save_dir()
+        pickle_file_path = TorchDistributor._save_pickled_function(save_dir, train_fn, *args)
+        output_file_path = os.path.join(save_dir, TorchDistributor.PICKLED_OUTPUT_FILE)
+        train_file_path = TorchDistributor._create_torchrun_train_file(
+            save_dir, pickle_file_path, output_file_path
+        )
+        try:
+            yield (train_file_path, output_file_path)
+        finally:
+            TorchDistributor._cleanup_files(save_dir)
+
+    @staticmethod
+    def _run_training_on_pytorch_function(
+        input_params: dict[str, Any], train_fn: Callable, *args: Any  # TODO: change dict to Dict
+    ) -> Any:
+        with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
+            args = []  # type: ignore
+            TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path, *args)
+            output = TorchDistributor._get_pickled_output(output_file_path)
+        return output
+
+    @staticmethod
+    def _create_save_dir() -> str:
+        # TODO: need to do this in a safe way to avoid issues during concurrent runs
+        return tempfile.mkdtemp()
+
+    @staticmethod
+    def _cleanup_files(save_dir: str) -> None:
+        shutil.rmtree(save_dir)

Review Comment:
   Please add the argument `ignore_errors=True`.
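   For illustration only, here is a minimal sketch of what the suggested change could look like. This is not the PR's code; the standalone `cleanup_files` helper below is a hypothetical stand-in for `TorchDistributor._cleanup_files`:

   ```python
   import shutil
   import tempfile


   def cleanup_files(save_dir: str) -> None:
       # ignore_errors=True makes rmtree swallow failures (e.g. the directory
       # was already removed or a file is still held open), so a cleanup
       # problem does not surface as a training-job failure.
       shutil.rmtree(save_dir, ignore_errors=True)


   if __name__ == "__main__":
       # Removing the same directory twice: the second call would raise
       # FileNotFoundError without ignore_errors=True.
       d = tempfile.mkdtemp()
       cleanup_files(d)
       cleanup_files(d)
   ```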