Posted to issues@spark.apache.org by "Weichen Xu (Jira)" <ji...@apache.org> on 2023/04/30 11:59:00 UTC

[jira] [Resolved] (SPARK-43081) Add torch distributor data loader that loads data from spark partition data

     [ https://issues.apache.org/jira/browse/SPARK-43081?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Weichen Xu resolved SPARK-43081.
--------------------------------
    Target Version/s: 3.5.0
          Resolution: Done

> Add torch distributor data loader that loads data from spark partition data
> ---------------------------------------------------------------------------
>
>                 Key: SPARK-43081
>                 URL: https://issues.apache.org/jira/browse/SPARK-43081
>             Project: Spark
>          Issue Type: Sub-task
>          Components: Connect, ML, PySpark
>    Affects Versions: 3.5.0
>            Reporter: Weichen Xu
>            Assignee: Weichen Xu
>            Priority: Major
>
> Add torch distributor data loader that loads data from spark partition data.
>  
> We can add two APIs:
> Adds a `TorchDistributor` method API:
> {code:python}
>      def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
>         """
>         Runs distributed training using the provided Spark DataFrame as input data.
>         You should ensure the input Spark DataFrame has evenly divided partitions;
>         this method starts a barrier Spark job in which each Spark task
>         processes one partition of the input Spark DataFrame.
>         Parameters
>         ----------
>         train_function :
>             Either a PyTorch function or a PyTorch Lightning function that launches
>             distributed training. Note that inside the function, you can call the
>             `pyspark.ml.torch.distributor.get_spark_partition_data_loader` API to get a torch
>             data loader; the data loader loads data from the corresponding partition of the
>             input Spark DataFrame.
>         spark_dataframe :
>             An input Spark DataFrame that can be used by the PyTorch `train_function`.
>             See the `train_function` argument doc for details.
>         args :
>             `args` are the positional input parameters to `train_function`. A call would
>             look like
>             >>> model = distributor.train_on_dataframe(train, df, 1e-3, 64)
>             where `train` is a function, `df` is the input Spark DataFrame, and 1e-3 and 64
>             are regular numeric inputs to the function.
>         kwargs :
>             `kwargs` are the keyword input parameters to `train_function`.
>             It would look like
>             >>> model = distributor.train_on_dataframe(train, df, tol=1e-3, max_iter=64)
>             where `train` is a function that has two keyword arguments, `tol` and `max_iter`.
>         Returns
>         -------
>             Returns the output of `train_function` called with `args` inside the Spark
>             rank 0 task.
>         """{code}
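>  
> As an illustration, a caller might wire the proposed method up like this (a minimal
> sketch assuming the docstring above; `train_fn`, the parquet path, and the session
> variable `spark` are hypothetical placeholders):
> {code:python}
> from pyspark.ml.torch.distributor import TorchDistributor
> 
> # Hypothetical training function; the positional and keyword arguments after the
> # DataFrame are forwarded to it, per the docstring above.
> def train_fn(learning_rate, batch_size=64):
>     ...  # build the model, iterate over a data loader, return the trained model
> 
> df = spark.read.parquet("/path/to/train_data")  # placeholder input DataFrame
> distributor = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False)
> model = distributor.train_on_dataframe(train_fn, df, 1e-3, batch_size=64)
> {code}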
>  
> Adds a loader API:
>  
> {code:python}
>  def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
>     """
>     This function must be called inside the `train_function` that is passed as the
>     input argument of `TorchDistributor.train_on_dataframe`.
>     The function returns a PyTorch data loader that loads data from
>     the corresponding Spark partition.
>     Parameters
>     ----------
>     num_samples :
>         Number of samples to generate per epoch. If `num_samples` is less than the number
>         of rows in the Spark partition, the loader yields only the first `num_samples` rows
>         of the Spark partition; if `num_samples` is greater than the number of rows in the
>         Spark partition, then after the iterator has loaded all rows from the partition,
>         it wraps around back to the first row.
>     batch_size :
>         How many samples per batch to load.
>     prefetch :
>         Number of batches loaded in advance.
>     """{code}
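>  
> Putting the two APIs together, `train_function` would obtain its loader like this (a
> sketch assuming the docstrings above; the batch structure, linear model, and training
> loop are placeholder code):
> {code:python}
> import torch
> from pyspark.ml.torch.distributor import get_spark_partition_data_loader
> 
> def train_fn(num_samples):
>     # Each barrier task gets a loader over its own Spark partition; per the
>     # docstring, iteration wraps around once the partition is exhausted.
>     loader = get_spark_partition_data_loader(num_samples, batch_size=32, prefetch=2)
>     model = torch.nn.Linear(4, 1)  # placeholder model
>     optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
>     for features, label in loader:  # assumed batch structure
>         optimizer.zero_grad()
>         loss = torch.nn.functional.mse_loss(model(features), label)
>         loss.backward()
>         optimizer.step()
>     return model
> {code}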



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org