This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c2060e7c0a3 [SPARK-43081][ML][FOLLOW-UP] Improve torch distributor data loader code
c2060e7c0a3 is described below
commit c2060e7c0a332c20f527adeb34a52042237430e4
Author: Weichen Xu <[email protected]>
AuthorDate: Wed May 31 16:34:19 2023 +0800
[SPARK-43081][ML][FOLLOW-UP] Improve torch distributor data loader code
### What changes were proposed in this pull request?
### Why are the changes needed?
Improve torch distributor data loader code:
* Add a check that `num_processes` matches the number of partitions of the
input Spark DataFrame. This makes debugging easier when users pass a
mismatched DataFrame; previously the torch package would surface an obscure
error instead. See the usage sketch after this list.
* Improve column value conversion in the torch data loader: resolve each
column's converter once up front instead of re-checking the field type for
every column value of every row. See the converter sketch after this list.
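As a usage illustration of the first point, here is a minimal sketch assuming
an existing SparkSession named `spark`; the DataFrame and the `num_processes`
value are hypothetical and not taken from the patch. The idea is that the
input DataFrame is repartitioned so its partition count equals
`num_processes`, which is what the new check in `_train_on_dataframe`
enforces.

```python
from pyspark.ml.torch.distributor import TorchDistributor

# Hypothetical setup: `spark` is an existing SparkSession.
num_processes = 2

# Repartition the input DataFrame so its partition count matches num_processes;
# with this patch a mismatch fails fast with a clear ValueError instead of an
# obscure torch error later on.
df = spark.range(0, 1000).repartition(num_processes)
assert df.rdd.getNumPartitions() == num_processes

distributor = TorchDistributor(
    num_processes=num_processes, local_mode=True, use_gpu=False
)
# `df` is then consumed by the internal _train_on_dataframe code path shown in
# the diff below.
```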
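And a simplified, standalone sketch of the second point, mirroring the
converter dispatch added to `_SparkPartitionTorchDataset`; the helper names
`get_field_converter` and `iter_converted_rows` are illustrative only. Each
column's converter is resolved once, so the per-row loop just applies a
callable instead of re-inspecting the field type for every value.

```python
from typing import Any, Callable, Iterator, List, Sequence

import numpy as np


def get_field_converter(field_type: str) -> Callable[[Any], Any]:
    # Resolve the conversion function for a column once, up front.
    if field_type == "vector":
        # Vector values arrive as struct-like rows with "type", "size",
        # "indices" and "values" fields, as in the patched dataset code.
        def converter(value: Any) -> Any:
            if value["type"] == 1:  # dense vector: values stored directly
                return value["values"]
            if value["type"] == 0:  # sparse vector: expand to a dense array
                size = int(value["size"])
                arr = np.zeros(size, dtype=np.float64)
                arr[value["indices"]] = value["values"]
                return arr

    elif field_type in ("float", "double", "int", "bigint", "smallint"):

        def converter(value: Any) -> Any:
            return value

    else:
        raise ValueError(f"Unsupported field type {field_type!r}.")
    return converter


def iter_converted_rows(
    rows: Iterator[Sequence[Any]], field_types: List[str]
) -> Iterator[list]:
    # Converters are looked up once per dataset, not once per column value.
    converters = [get_field_converter(t) for t in field_types]
    for row in rows:
        yield [convert(value) for value, convert in zip(row, converters)]
```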
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT.
Closes #41382 from WeichenXu123/improve-torch-dataloader.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
python/pyspark/ml/torch/data.py | 50 ++++++++++++++++++++--------------
python/pyspark/ml/torch/distributor.py | 6 ++++
2 files changed, 36 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/ml/torch/data.py b/python/pyspark/ml/torch/data.py
index a52b96a9392..d421683c16d 100644
--- a/python/pyspark/ml/torch/data.py
+++ b/python/pyspark/ml/torch/data.py
@@ -17,7 +17,7 @@
import torch
import numpy as np
-from typing import Any, Iterator
+from typing import Any, Callable, Iterator
from pyspark.sql.types import StructType
@@ -26,27 +26,37 @@ class _SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
self.arrow_file_path = arrow_file_path
self.num_samples = num_samples
self.field_types = [field.dataType.simpleString() for field in schema]
+ self.field_converters = [
+ _SparkPartitionTorchDataset._get_field_converter(field_type)
+ for field_type in self.field_types
+ ]
@staticmethod
- def _extract_field_value(value: Any, field_type: str) -> Any:
- # TODO: avoid checking field type for every row.
+ def _get_field_converter(field_type: str) -> Callable[[Any], Any]:
if field_type == "vector":
- if value["type"] == 1:
- # dense vector
- return value["values"]
- if value["type"] == 0:
- # sparse vector
- size = int(value["size"])
- sparse_array = np.zeros(size, dtype=np.float64)
- sparse_array[value["indices"]] = value["values"]
- return sparse_array
- if field_type in ["float", "double", "int", "bigint", "smallint"]:
- return value
- raise ValueError(
- "SparkPartitionTorchDataset does not support loading data from
field of "
- f"type {field_type}."
- )
+ def converter(value: Any) -> Any:
+ if value["type"] == 1:
+ # dense vector
+ return value["values"]
+ if value["type"] == 0:
+ # sparse vector
+ size = int(value["size"])
+ sparse_array = np.zeros(size, dtype=np.float64)
+ sparse_array[value["indices"]] = value["values"]
+ return sparse_array
+
+ elif field_type in ["float", "double", "int", "bigint", "smallint"]:
+
+ def converter(value: Any) -> Any:
+ return value
+
+ else:
+ raise ValueError(
+ "SparkPartitionTorchDataset does not support loading data from
field of "
+ f"type {field_type}."
+ )
+ return converter
def __iter__(self) -> Iterator[Any]:
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
@@ -71,8 +81,8 @@ class _SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
batch_pdf = batch.to_pandas()
for row in batch_pdf.itertuples(index=False):
yield [
- _SparkPartitionTorchDataset._extract_field_value(value, field_type)
- for value, field_type in zip(row, self.field_types)
+ field_converter(value)
+ for value, field_converter in zip(row, self.field_converters)
]
count += 1
if count == self.num_samples:
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 0249e6b4b2c..711f76db09b 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -591,6 +591,12 @@ class TorchDistributor(Distributor):
os.environ["NODE_RANK"] = str(context.partitionId())
os.environ["RANK"] = str(context.partitionId())
+ if context.partitionId() >= num_processes:
+ raise ValueError(
+ "TorchDistributor._train_on_dataframe requires setting
num_processes "
+ "equal to input spark dataframe partition number."
+ )
+
if is_spark_local_master:
# distributed training on a local mode spark cluster
def set_gpus(context: "BarrierTaskContext") -> None: