[ https://issues.apache.org/jira/browse/SPARK-41775?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17687743#comment-17687743 ]
Apache Spark commented on SPARK-41775:
--------------------------------------
User 'rithwik-db' has created a pull request for this issue:
https://github.com/apache/spark/pull/39987
> Implement training functions as input
> -------------------------------------
>
> Key: SPARK-41775
> URL: https://issues.apache.org/jira/browse/SPARK-41775
> Project: Spark
> Issue Type: Sub-task
> Components: ML, PySpark
> Affects Versions: 3.4.0
> Reporter: Rithwik Ediga Lakhamsani
> Assignee: Rithwik Ediga Lakhamsani
> Priority: Major
> Fix For: 3.4.0
>
>
> Side note: make the formatting updates described in
> https://github.com/apache/spark/pull/39188
>
> Currently, `Distributor().run(...)` takes only files as input. We will add
> functionality so that it accepts functions as well. This requires the
> following process on each task on the executor nodes:
> 1. Take the input function and its args and pickle them
> 2. Create a temp train.py file that looks like:
> {code:python}
> import os
>
> import cloudpickle
>
> if __name__ == "__main__":
>     # Deserialize the training function and its arguments.
>     with open(f"{tempdir}/train_input.pkl", "rb") as f:
>         train, args = cloudpickle.load(f)
>     output = train(*args)
>     # Only rank 0 (i.e. partitionId == 0) persists the result.
>     if output is not None and os.environ.get("RANK", "") == "0":
>         with open(f"{tempdir}/train_output.pkl", "wb") as f:
>             cloudpickle.dump(output, f)
> {code}
> 3. Run that train.py file with `torchrun`
> 4. Check whether `train_output.pkl` has been created on the process with
> partitionId == 0; if it has, deserialize it and return that output through
> `.collect()`
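>
> A minimal driver-side sketch of steps 1, 3, and 4, assuming a hypothetical
> helper named `run_function` and a plain local `torchrun` launch (the real
> per-task launch command and distributed env setup are not specified here):
> {code:python}
> import os
> import subprocess
> import tempfile
>
> import cloudpickle
>
>
> def run_function(train_fn, args, num_processes=2):
>     with tempfile.TemporaryDirectory() as tempdir:
>         # Step 1: pickle the input function and its args.
>         with open(os.path.join(tempdir, "train_input.pkl"), "wb") as f:
>             cloudpickle.dump((train_fn, args), f)
>         # Step 2 would write the generated train.py (shown above) into tempdir.
>         # Step 3: run that file with torchrun.
>         subprocess.run(
>             ["torchrun", f"--nproc_per_node={num_processes}",
>              os.path.join(tempdir, "train.py")],
>             check=True,
>         )
>         # Step 4: if the rank-0 process produced an output, return it.
>         output_path = os.path.join(tempdir, "train_output.pkl")
>         if os.path.exists(output_path):
>             with open(output_path, "rb") as f:
>                 return cloudpickle.load(f)
>         return None
> {code}
> With something like this in place, a call such as
> `Distributor().run(train_fn, arg1, arg2)` could go through this path instead
> of requiring a file path.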