This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c2060e7c0a3 [SPARK-43081][ML][FOLLOW-UP] Improve torch distributor data loader code
c2060e7c0a3 is described below

commit c2060e7c0a332c20f527adeb34a52042237430e4
Author: Weichen Xu <[email protected]>
AuthorDate: Wed May 31 16:34:19 2023 +0800

    [SPARK-43081][ML][FOLLOW-UP] Improve torch distributor data loader code
    
    ### What changes were proposed in this pull request?
    
    Improve the torch distributor data loader code:
    
    * Add a verification that num_processes matches the number of partitions of the input Spark DataFrame (a usage sketch follows below).
    * Improve column value conversion in the torch data loader: build a converter for each column once, instead of checking the field type for every value of every row.
    
    ### Why are the changes needed?
    
    The verification makes debugging easier when users pass a DataFrame with a mismatched partition count; without it, the torch package raises a hard-to-interpret error. Precomputing the per-column converters avoids a redundant type comparison on every row.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    Closes #41382 from WeichenXu123/improve-torch-dataloader.
    
    Authored-by: Weichen Xu <[email protected]>
    Signed-off-by: Weichen Xu <[email protected]>
---
 python/pyspark/ml/torch/data.py        | 50 ++++++++++++++++++++--------------
 python/pyspark/ml/torch/distributor.py |  6 ++++
 2 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/python/pyspark/ml/torch/data.py b/python/pyspark/ml/torch/data.py
index a52b96a9392..d421683c16d 100644
--- a/python/pyspark/ml/torch/data.py
+++ b/python/pyspark/ml/torch/data.py
@@ -17,7 +17,7 @@
 
 import torch
 import numpy as np
-from typing import Any, Iterator
+from typing import Any, Callable, Iterator
 from pyspark.sql.types import StructType
 
 
@@ -26,27 +26,37 @@ class _SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
         self.arrow_file_path = arrow_file_path
         self.num_samples = num_samples
         self.field_types = [field.dataType.simpleString() for field in schema]
+        self.field_converters = [
+            _SparkPartitionTorchDataset._get_field_converter(field_type)
+            for field_type in self.field_types
+        ]
 
     @staticmethod
-    def _extract_field_value(value: Any, field_type: str) -> Any:
-        # TODO: avoid checking field type for every row.
+    def _get_field_converter(field_type: str) -> Callable[[Any], Any]:
         if field_type == "vector":
-            if value["type"] == 1:
-                # dense vector
-                return value["values"]
-            if value["type"] == 0:
-                # sparse vector
-                size = int(value["size"])
-                sparse_array = np.zeros(size, dtype=np.float64)
-                sparse_array[value["indices"]] = value["values"]
-                return sparse_array
-        if field_type in ["float", "double", "int", "bigint", "smallint"]:
-            return value
 
-        raise ValueError(
-            "SparkPartitionTorchDataset does not support loading data from 
field of "
-            f"type {field_type}."
-        )
+            def converter(value: Any) -> Any:
+                if value["type"] == 1:
+                    # dense vector
+                    return value["values"]
+                if value["type"] == 0:
+                    # sparse vector
+                    size = int(value["size"])
+                    sparse_array = np.zeros(size, dtype=np.float64)
+                    sparse_array[value["indices"]] = value["values"]
+                    return sparse_array
+
+        elif field_type in ["float", "double", "int", "bigint", "smallint"]:
+
+            def converter(value: Any) -> Any:
+                return value
+
+        else:
+            raise ValueError(
+                "SparkPartitionTorchDataset does not support loading data from 
field of "
+                f"type {field_type}."
+            )
+        return converter
 
     def __iter__(self) -> Iterator[Any]:
         from pyspark.sql.pandas.serializers import ArrowStreamSerializer
@@ -71,8 +81,8 @@ class _SparkPartitionTorchDataset(torch.utils.data.IterableDataset):
                     batch_pdf = batch.to_pandas()
                     for row in batch_pdf.itertuples(index=False):
                         yield [
-                            _SparkPartitionTorchDataset._extract_field_value(value, field_type)
-                            for value, field_type in zip(row, self.field_types)
+                            field_converter(value)
+                            for value, field_converter in zip(row, self.field_converters)
                         ]
                         count += 1
                         if count == self.num_samples:
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 0249e6b4b2c..711f76db09b 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -591,6 +591,12 @@ class TorchDistributor(Distributor):
                 os.environ["NODE_RANK"] = str(context.partitionId())
                 os.environ["RANK"] = str(context.partitionId())
 
+                if context.partitionId() >= num_processes:
+                    raise ValueError(
+                        "TorchDistributor._train_on_dataframe requires setting 
num_processes "
+                        "equal to input spark dataframe partition number."
+                    )
+
             if is_spark_local_master:
                 # distributed training on a local mode spark cluster
                 def set_gpus(context: "BarrierTaskContext") -> None:


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
