WeichenXu123 commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1161679898
##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -668,13 +668,17 @@ def _setup_files(train_fn: Callable, *args: Any,
**kwargs) -> Generator[Tuple[st
@contextmanager
def _setup_spark_partition_data(partition_data_iterator,
input_schema_json):
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
+ from pyspark.files import SparkFiles
import json
if input_schema_json is None:
yield
return
- save_dir = TorchDistributor._create_save_dir()
+ # We need to temporarily write partition data into a temp dir,
+ # partition data might be huge, so we need to write it under
+ # configured `SPARK_LOCAL_DIRS`.
+ save_dir = TorchDistributor._create_save_dir(root_dir=SparkFiles.getRootDirectory())
Review Comment:
@HyukjinKwon Could you help check this part? :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]