[ https://issues.apache.org/jira/browse/SPARK-43081?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Weichen Xu updated SPARK-43081:
-------------------------------
    Description: 
Add a torch distributor data loader that loads data from Spark DataFrame partitions.

We can add two APIs.

A `TorchDistributor` method:
{code:python}
def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
    """
    Runs distributed training using the provided Spark DataFrame as input data.
    You should ensure that the input Spark DataFrame has evenly divided
    partitions. This method starts a barrier Spark job in which each Spark
    task processes one partition of the input Spark DataFrame.

    Parameters
    ----------
    train_function :
        Either a PyTorch function or a PyTorch Lightning function that launches
        distributed training. Inside the function, you can call the
        `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to
        get a torch data loader that loads data from the corresponding
        partition of the input Spark DataFrame.
    spark_dataframe :
        An input Spark DataFrame that can be used in the PyTorch
        `train_function` function. See the `train_function` argument doc for
        details.
    args :
        `args` need to be the positional input parameters to `train_function`.
        It would look like

        >>> model = distributor.train_on_dataframe(train, df, 1e-3, 64)

        where `train` is a function, `df` is the input Spark DataFrame, and
        1e-3 and 64 are regular numeric inputs to the function.
    kwargs :
        `kwargs` need to be the keyword input parameters to `train_function`.
        It would look like

        >>> model = distributor.train_on_dataframe(train, df, tol=1e-3, max_iter=64)

        where `train` is a function that has 2 arguments, `tol` and `max_iter`.

    Returns
    -------
        Returns the output of `train_function`, called with the given `args`
        and `kwargs`, from the Spark task with rank 0.
    """{code}
 

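To make the `train_on_dataframe` contract concrete, here is a pure-Python simulation with no Spark or torch dependency. It is one plausible reading of the docstring, not the proposed implementation: it assumes one task per partition, forwards `*args`/`**kwargs` to `train_function` unchanged, and returns only the rank 0 output. `simulate_train_on_dataframe` and `get_partition_rows` are hypothetical names, not PySpark APIs. The thread-local lookup mimics why the proposed loader must be called inside `train_function`: it reads per-task state instead of receiving the DataFrame as an argument.

{code:python}
import threading
from typing import Any, Callable, List, Sequence

# Hypothetical stand-in for a Spark task context: each simulated task
# records the rows of the partition it owns before calling train_function.
_task_context = threading.local()


def get_partition_rows() -> Sequence[Any]:
    """Hypothetical accessor, valid only inside train_function."""
    rows = getattr(_task_context, "rows", None)
    if rows is None:
        raise RuntimeError("get_partition_rows() must be called inside train_function")
    return rows


def simulate_train_on_dataframe(
    train_function: Callable[..., Any],
    partitions: List[Sequence[Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Call train_function once per partition; return the rank 0 output.

    Tasks run one after another here, whereas a barrier job runs them together.
    """
    if not partitions:
        raise ValueError("need at least one partition")
    outputs = []
    for rows in partitions:  # list position plays the role of task rank
        _task_context.rows = rows
        try:
            outputs.append(train_function(*args, **kwargs))
        finally:
            _task_context.rows = None  # do not leak the partition past the task
    return outputs[0]  # only the rank 0 output is returned


# Usage: positional and keyword arguments are forwarded to train_function.
def train(lr: float, max_iter: int = 10) -> dict:
    rows = get_partition_rows()  # analogous to requesting the partition loader
    return {"lr": lr, "max_iter": max_iter, "first_row": rows[0]}


print(simulate_train_on_dataframe(train, [[1, 2, 3], [4, 5]], 1e-3, max_iter=64))
# -> {'lr': 0.001, 'max_iter': 64, 'first_row': 1}
{code}

Because the output of every task other than rank 0 is discarded, any result the caller needs (for example, the trained model) has to be produced by the rank 0 task.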

A data loader API:

{code:python}
def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
    """
    This function must be called inside `train_function`, where
    `train_function` is the input argument of
    `TorchDistributor.train_on_dataframe`. It returns a PyTorch data loader
    that loads data from the corresponding Spark partition.

    Parameters
    ----------
    num_samples :
        Number of samples to generate per epoch. If `num_samples` is less than
        the number of rows in the Spark partition, the loader generates the
        first `num_samples` rows of the partition. If `num_samples` is greater
        than the number of rows in the partition, then after the iterator has
        loaded all rows from the partition, it wraps around back to the first
        row.
    batch_size :
        How many samples per batch to load.
    prefetch :
        Number of batches loaded in advance.
    """{code}

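The `num_samples`, `batch_size`, and `prefetch` semantics described above can be sketched in plain Python, with a list standing in for the partition's rows and lists of rows standing in for tensors. `partition_batches` and `sample_rows` are hypothetical names, and this is not the proposed implementation. The sketch assumes that the last, possibly short, batch is kept (as a torch `DataLoader` does with `drop_last=False`) and that "loaded in advance" means a background thread filling a queue bounded at `prefetch` batches.

{code:python}
import itertools
import queue
import threading
from typing import Any, Iterator, List, Sequence

_DONE = object()  # sentinel: the producer has finished the epoch


def sample_rows(rows: Sequence[Any], num_samples: int) -> Iterator[Any]:
    """Yield exactly num_samples rows, wrapping back to the first row if needed."""
    if not rows:
        raise ValueError("cannot sample from an empty partition")
    return itertools.islice(itertools.cycle(rows), num_samples)


def partition_batches(
    rows: Sequence[Any], num_samples: int, batch_size: int, prefetch: int = 2
) -> Iterator[List[Any]]:
    """Yield one epoch of batches while a background thread builds the next ones."""
    if batch_size < 1 or prefetch < 1:
        raise ValueError("batch_size and prefetch must be at least 1")
    # Bounded queue: the producer blocks once `prefetch` finished batches are waiting.
    buffer = queue.Queue(maxsize=prefetch)

    def producer() -> None:
        try:
            samples = sample_rows(rows, num_samples)
            while True:
                batch = list(itertools.islice(samples, batch_size))
                if not batch:
                    break
                buffer.put(batch)
        except Exception as exc:  # hand errors to the consuming thread
            buffer.put(exc)
            return
        buffer.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buffer.get()
        if item is _DONE:
            return
        if isinstance(item, Exception):
            raise item
        yield item


print(list(partition_batches([0, 1, 2, 3, 4], num_samples=7, batch_size=3)))
# -> [[0, 1, 2], [3, 4, 0], [1]]
{code}

Bounding the queue keeps buffered memory proportional to `prefetch` rather than to the partition size. If the consumer stops early, the daemon producer thread simply stays blocked until the process exits.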

> Add torch distributor data loader that loads data from spark partition data
> ---------------------------------------------------------------------------
>
>                 Key: SPARK-43081
>                 URL: https://issues.apache.org/jira/browse/SPARK-43081
>             Project: Spark
>          Issue Type: Sub-task
>          Components: Connect, ML, PySpark
>    Affects Versions: 3.5.0
>            Reporter: Weichen Xu
>            Priority: Major



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
