This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a758d6a0f9d [SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x
a758d6a0f9d is described below
commit a758d6a0f9dfa32881cfcec263da0ab0c02f5c1d
Author: Weichen Xu <[email protected]>
AuthorDate: Tue Jun 27 07:57:21 2023 -0700
[SPARK-43081][FOLLOW-UP][ML][CONNECT] Make torch dataloader support torch 1.x
### What changes were proposed in this pull request?
Make torch dataloader support torch 1.x.
Currently, when running with torch 1.x and `num_workers` set to 0, an error like the following is raised:
```
ValueError: prefetch_factor option could only be specified in multiprocessing.let num_workers > 0 to enable multiprocessing.
```
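For context, a minimal standalone sketch of the version-agnostic pattern the fix adopts: only forward `prefetch_factor` to `DataLoader` when worker processes are enabled. This is not part of the patch; `make_loader` and the `TensorDataset` below are illustrative.
```
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, batch_size, num_workers, prefetch_factor=2):
    # Only forward prefetch_factor when multiprocessing workers are enabled;
    # combining an explicit prefetch_factor with num_workers == 0 can trigger
    # the ValueError shown above on older torch releases.
    if num_workers > 0:
        return DataLoader(
            dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor
        )
    return DataLoader(dataset, batch_size, num_workers=num_workers)

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
loader = make_loader(dataset, batch_size=4, num_workers=0)  # works on torch 1.x and 2.x
print([batch for batch in loader])
```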
### Why are the changes needed?
Compatibility fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Manually ran unit tests with torch 1.x.
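A rough, self-contained sketch of the kind of check such a test could perform; this is not the actual Spark test suite, and the dataset, batch size, and assertion are illustrative.
```
import unittest

import torch
from torch.utils.data import DataLoader, TensorDataset

class DataLoaderCompatTest(unittest.TestCase):
    def test_no_workers_builds_without_prefetch_factor(self):
        dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
        # With num_workers == 0 the loader is built without prefetch_factor,
        # so construction succeeds on torch 1.x as well as 2.x.
        loader = DataLoader(dataset, 4, num_workers=0)
        self.assertEqual(len(list(loader)), 2)

if __name__ == "__main__":
    unittest.main()
```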
Closes #41751 from WeichenXu123/support-torch-1.x.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 9f9636e6b10..8b34acd959e 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -995,4 +995,11 @@ def _get_spark_partition_data_loader(
     dataset = _SparkPartitionTorchDataset(arrow_file, schema, num_samples)
-    return DataLoader(dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor)
+    if num_workers > 0:
+        return DataLoader(
+            dataset, batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor
+        )
+    else:
+        # if num_workers is zero, we cannot set `prefetch_factor`, otherwise
+        # torch will raise an error.
+        return DataLoader(dataset, batch_size, num_workers=num_workers)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]