HyukjinKwon commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1169445512


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -643,13 +664,52 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N
         finally:
             TorchDistributor._cleanup_files(save_dir)
 
+    @staticmethod
+    @contextmanager
+    def _setup_spark_partition_data(partition_data_iterator, input_schema_json):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        from pyspark.files import SparkFiles
+        import json
+
+        if input_schema_json is None:
+            yield
+            return
+
+        # We need to temporarily write partition data into a temp dir.
+        # Partition data might be huge, so we write it under the
+        # configured `SPARK_LOCAL_DIRS`.
+        save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())
+
+        try:
+            serializer = ArrowStreamSerializer()
+            arrow_file_path = os.path.join(save_dir, "data.arrow")

Review Comment:
   You might need to remove the directory if it already exists. If a Python UDF
fails (with the Python worker reused), the retry will attempt to reuse this
path, if I am not mistaken.
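
   A minimal sketch of what the suggested guard could look like (the helper
name `_prepare_save_dir` is illustrative, not the PR's actual code): remove
any leftover directory from a previous failed attempt before creating it, so
a retried task on a reused Python worker starts from a clean state.

   ```python
   import os
   import shutil

   def _prepare_save_dir(save_dir: str) -> str:
       # If a reused Python worker retries the task, the directory (and the
       # "data.arrow" file inside it) may already exist from the failed
       # attempt; remove it first so the retry starts clean.
       if os.path.exists(save_dir):
           shutil.rmtree(save_dir)
       os.makedirs(save_dir)
       return save_dir
   ```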



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

