HyukjinKwon commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1169447530
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -643,13 +664,52 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N
         finally:
             TorchDistributor._cleanup_files(save_dir)
+    @staticmethod
+    @contextmanager
+    def _setup_spark_partition_data(partition_data_iterator, input_schema_json):
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+        from pyspark.files import SparkFiles
+        import json
+
+        if input_schema_json is None:
+            yield
+            return
+
+        # We need to temporarily write partition data into a temp dir,
+        # partition data might be huge, so we need to write it under
+        # configured `SPARK_LOCAL_DIRS`.
+        save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())
+
+        try:
+            serializer = ArrowStreamSerializer()
+            arrow_file_path = os.path.join(save_dir, "data.arrow")
Review Comment:
Oh, actually never mind: `with open(arrow_file_path, "wb") as f:` truncates the
file, so it always writes it from the beginning.
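
To illustrate the reviewer's point with a stdlib-only sketch: opening a file in `"wb"` mode truncates it, so a second write starts from an empty file rather than appending to stale bytes. The `save_dir_for_partition_data` helper below is hypothetical, standing in for the `_create_save_dir` / `_cleanup_files` pair in the diff; it is not the actual PySpark implementation.

```python
import os
import tempfile
from contextlib import contextmanager

# Hypothetical stand-in for the pattern in the diff: create a save dir,
# let the caller use it, and clean it up in `finally`, mirroring
# TorchDistributor._create_save_dir / TorchDistributor._cleanup_files.
@contextmanager
def save_dir_for_partition_data(root_dir):
    save_dir = tempfile.mkdtemp(dir=root_dir)
    try:
        yield save_dir
    finally:
        for name in os.listdir(save_dir):
            os.remove(os.path.join(save_dir, name))
        os.rmdir(save_dir)

# The reviewer's point: "wb" truncates on open, so a rewrite replaces
# the old contents instead of appending after them.
with tempfile.TemporaryDirectory() as root:
    with save_dir_for_partition_data(root) as d:
        arrow_file_path = os.path.join(d, "data.arrow")
        with open(arrow_file_path, "wb") as f:
            f.write(b"stale partition bytes")
        with open(arrow_file_path, "wb") as f:
            f.write(b"fresh")
        contents = open(arrow_file_path, "rb").read()

print(contents)  # b'fresh'
```

Because of this truncation semantics, no explicit delete of a leftover `data.arrow` is needed before rewriting it.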
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]