This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e1619653895  [SPARK-41593][FOLLOW-UP] Fix the case torch distributor logging server not shut down
e1619653895 is described below

commit e1619653895b4d5e11d7121bdb7906355d8c17bf
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue May 30 19:13:20 2023 +0800

    [SPARK-41593][FOLLOW-UP] Fix the case torch distributor logging server not shut down

    ### What changes were proposed in this pull request?

    Fix the case where the torch distributor logging server is not shut down.
    `_get_spark_task_function` and `_check_encryption` might raise an exception;
    in that case the logging server must be shut down, but it was not. This PR
    fixes that by moving these calls inside the `try` block.

    ### Why are the changes needed?

    Fix the case where the torch distributor logging server is not shut down.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Unit tests.

    Closes #41375 from WeichenXu123/improve-torch-distributor-log-server-exception-handling.

    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 python/pyspark/ml/torch/distributor.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ad8b4d8cc25..0249e6b4b2c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -665,20 +665,20 @@ class TorchDistributor(Distributor):
         time.sleep(1)  # wait for the server to start
         self.log_streaming_server_port = log_streaming_server.port
 
-        spark_task_function = self._get_spark_task_function(
-            framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
-        )
-        self._check_encryption()
-        self.logger.info(
-            f"Started distributed training with {self.num_processes} executor processes"
-        )
-        if spark_dataframe is not None:
-            input_df = spark_dataframe
-        else:
-            input_df = self.spark.range(
-                start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
-            )
         try:
+            spark_task_function = self._get_spark_task_function(
+                framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
+            )
+            self._check_encryption()
+            self.logger.info(
+                f"Started distributed training with {self.num_processes} executor processes"
+            )
+            if spark_dataframe is not None:
+                input_df = spark_dataframe
+            else:
+                input_df = self.spark.range(
+                    start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
+                )
             rows = input_df.mapInArrow(
                 func=spark_task_function, schema="chunk binary", barrier=True
             ).collect()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
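For context, the cleanup pattern this commit adopts can be sketched outside Spark: setup steps that may raise are moved inside the `try` block so the cleanup path (here a `finally` clause; the truncated hunk does not show the distributor's actual cleanup code) always shuts the server down. `LogStreamingServer` and `run_distributed` below are hypothetical stand-ins for illustration, not PySpark APIs.

```python
# Minimal sketch of the fix's resource-cleanup pattern, using hypothetical
# names; this is NOT the PySpark implementation.


class LogStreamingServer:
    def __init__(self) -> None:
        self.running = True  # pretend the server has started

    def shutdown(self) -> None:
        self.running = False


def run_distributed(setup_fn, server: LogStreamingServer):
    # Before the fix, the setup step ran outside the try block: if it
    # raised, shutdown() was never reached. Moving it inside guarantees
    # the server is shut down on both success and failure.
    try:
        task = setup_fn()  # may raise, like _get_spark_task_function
        return task()
    finally:
        server.shutdown()  # always runs, even when setup_fn raises


if __name__ == "__main__":
    server = LogStreamingServer()

    def bad_setup():
        raise RuntimeError("setup failed")

    try:
        run_distributed(bad_setup, server)
    except RuntimeError:
        pass

    assert server.running is False  # shut down despite the setup error
```

The same guarantee could be obtained with a context manager (`with contextlib.closing(...)`), which is the idiomatic Python spelling of try/finally cleanup.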