Posted to issues@spark.apache.org by "Weichen Xu (Jira)" <ji...@apache.org> on 2023/04/10 12:07:00 UTC
[jira] [Assigned] (SPARK-43081) Add torch distributor data loader that loads data from spark partition data
[ https://issues.apache.org/jira/browse/SPARK-43081?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Weichen Xu reassigned SPARK-43081:
----------------------------------
Assignee: Weichen Xu
> Add torch distributor data loader that loads data from spark partition data
> ---------------------------------------------------------------------------
>
> Key: SPARK-43081
> URL: https://issues.apache.org/jira/browse/SPARK-43081
> Project: Spark
> Issue Type: Sub-task
> Components: Connect, ML, PySpark
> Affects Versions: 3.5.0
> Reporter: Weichen Xu
> Assignee: Weichen Xu
> Priority: Major
>
> Add a `TorchDistributor` data loader that loads data from Spark partition data.
>
> We can add two APIs.
> Add a `TorchDistributor` method:
> {code:java}
> def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):
>     """
>     Runs distributed training using the provided Spark DataFrame as input data.
>     You should ensure that the input Spark DataFrame has evenly divided partitions.
>     This method starts a barrier Spark job in which each Spark task
>     processes one partition of the input Spark DataFrame.
>
>     Parameters
>     ----------
>     train_function :
>         Either a PyTorch function or a PyTorch Lightning function that launches
>         distributed training. Inside the function, you can call
>         `pyspark.ml.torch.distributor.get_spark_partition_data_loader` to get a torch
>         data loader that loads data from the corresponding partition of the
>         input Spark DataFrame.
>     spark_dataframe :
>         An input Spark DataFrame that can be used in the PyTorch `train_function`.
>         See the `train_function` argument doc for details.
>     args :
>         Positional arguments passed to `train_function`. For example:
>         >>> model = distributor.train_on_dataframe(train, spark_dataframe, 1e-3, 64)
>         where `train` is a function and 1e-3 and 64 are regular numeric inputs to it.
>     kwargs :
>         Keyword arguments passed to `train_function`. For example:
>         >>> model = distributor.train_on_dataframe(train, spark_dataframe, tol=1e-3, max_iter=64)
>         where `train` is a function that takes two arguments, `tol` and `max_iter`.
>
>     Returns
>     -------
>         Returns the output of `train_function` called with args inside the Spark rank 0 task.
>     """{code}
>
> Add a data loader function:
>
> {code:java}
> def get_spark_partition_data_loader(num_samples, batch_size, prefetch=2):
>     """
>     This function must be called inside the `train_function` that is passed
>     to `TorchDistributor.train_on_dataframe`.
>     It returns a PyTorch data loader that loads data from
>     the corresponding Spark partition.
>
>     Parameters
>     ----------
>     num_samples :
>         Number of samples to generate per epoch. If `num_samples` is less than the number of
>         rows in the Spark partition, the loader generates only the first `num_samples` rows
>         of the partition. If `num_samples` is greater than the number of rows in the
>         partition, then after the iterator has loaded all rows from the partition,
>         it wraps around to the first row.
>     batch_size :
>         Number of samples to load per batch.
>     prefetch :
>         Number of batches loaded in advance.
>     """{code}
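The `num_samples` semantics described above (truncate when the partition is larger, wrap around when it is smaller) can be illustrated with a minimal pure-Python sketch. This is not the proposed implementation: `iter_partition_batches` is a hypothetical helper that takes the partition rows as an in-memory sequence, ignores `prefetch`, and assumes the last batch of an epoch may be shorter than `batch_size`, which the proposal does not specify.

```python
from itertools import cycle, islice
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_partition_batches(
    rows: Sequence[T], num_samples: int, batch_size: int
) -> Iterator[List[T]]:
    """Yield one epoch of batches from a single partition's rows.

    Hypothetical illustration only: it models the truncate / wrap-around
    rule for `num_samples`, not Spark, PyTorch, or prefetching.
    """
    if not rows:
        raise ValueError("the partition must contain at least one row")
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")

    # cycle() restarts at the first row once the partition is exhausted;
    # islice() stops after exactly `num_samples` rows have been produced.
    samples = islice(cycle(rows), num_samples)

    while True:
        # Pull up to `batch_size` rows; the final batch may be shorter
        # (an assumption of this sketch).
        batch = list(islice(samples, batch_size))
        if not batch:
            return
        yield batch


if __name__ == "__main__":
    partition = ["a", "b", "c"]
    # num_samples > partition size: wraps around to the first row.
    print(list(iter_partition_batches(partition, num_samples=5, batch_size=2)))
    # num_samples < partition size: only the first rows are used.
    print(list(iter_partition_batches(partition, num_samples=2, batch_size=2)))
```

With a three-row partition, `num_samples=5` and `batch_size=2` give `[['a', 'b'], ['c', 'a'], ['b']]`, while `num_samples=2` gives `[['a', 'b']]`. Fixing `num_samples` per task in this way lets every rank run the same number of steps per epoch even if partition sizes differ slightly.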
--
This message was sent by Atlassian Jira
(v8.20.10#820010)