[ https://issues.apache.org/jira/browse/SPARK-41775?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=17687743#comment-17687743 ]
Apache Spark commented on SPARK-41775:
--------------------------------------

User 'rithwik-db' has created a pull request for this issue:
https://github.com/apache/spark/pull/39987

> Implement training functions as input
> -------------------------------------
>
>                 Key: SPARK-41775
>                 URL: https://issues.apache.org/jira/browse/SPARK-41775
>             Project: Spark
>          Issue Type: Sub-task
>          Components: ML, PySpark
>    Affects Versions: 3.4.0
>            Reporter: Rithwik Ediga Lakhamsani
>            Assignee: Rithwik Ediga Lakhamsani
>            Priority: Major
>             Fix For: 3.4.0
>
>
> Sidenote: make formatting updates described in
> https://github.com/apache/spark/pull/39188
>
> Currently, `Distributor().run(...)` accepts only files as input. We will
> add support for passing functions as well. This requires the following
> steps on each task on the executor nodes:
> 1. Pickle the input function and its args.
> 2. Create a temporary train.py file that looks like
> {code:python}
> import cloudpickle
> import os
>
> if __name__ == "__main__":
>     with open(f"{tempdir}/train_input.pkl", "rb") as f:
>         train, args = cloudpickle.load(f)
>     output = train(*args)
>     # Only the process with RANK 0 (partitionId == 0) writes the output
>     if output and os.environ.get("RANK", "") == "0":
>         with open(f"{tempdir}/train_output.pkl", "wb") as f:
>             cloudpickle.dump(output, f)
> {code}
> 3. Run that train.py file with `torchrun`.
> 4. Check whether `train_output.pkl` has been created by the process with
> partitionId == 0; if it has, deserialize it and return that output
> through `.collect()`.
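
The four steps in the issue description can be sketched end to end on a single machine. This is a minimal illustration under stated assumptions, not the code from the pull request: `run_training_function` and `TRAIN_TEMPLATE` are hypothetical names, the generated script is launched with the current Python interpreter rather than `torchrun`, and `RANK` is set by hand to mimic the environment variable the issue's snippet checks. If `cloudpickle` is not installed, the sketch falls back to the standard `pickle` module, which only handles importable functions.

```python
import operator
import os
import subprocess
import sys
import tempfile
import textwrap

try:
    # The issue proposes cloudpickle, which can serialize lambdas and closures.
    import cloudpickle as pickler
except ImportError:
    # Stand-in for this sketch only: works for importable functions.
    import pickle as pickler

# Step 2 template: mirrors the train.py described in the issue.
TRAIN_TEMPLATE = textwrap.dedent(
    """\
    import os
    import {pickler_name} as pickler

    if __name__ == "__main__":
        with open({input_path!r}, "rb") as f:
            train, args = pickler.load(f)
        output = train(*args)
        # Only the process with RANK 0 writes the output file.
        if output and os.environ.get("RANK", "") == "0":
            with open({output_path!r}, "wb") as f:
                pickler.dump(output, f)
    """
)


def run_training_function(train_fn, *args, rank="0"):
    """Hypothetical helper: run train_fn(*args) through a generated train.py."""
    with tempfile.TemporaryDirectory() as tempdir:
        input_path = os.path.join(tempdir, "train_input.pkl")
        output_path = os.path.join(tempdir, "train_output.pkl")
        script_path = os.path.join(tempdir, "train.py")

        # Step 1: pickle the input function and its args.
        with open(input_path, "wb") as f:
            pickler.dump((train_fn, args), f)

        # Step 2: create the temporary train.py file.
        with open(script_path, "w") as f:
            f.write(
                TRAIN_TEMPLATE.format(
                    pickler_name=pickler.__name__,
                    input_path=input_path,
                    output_path=output_path,
                )
            )

        # Step 3: run the script. Assumption: plain Python instead of
        # torchrun, with RANK set manually in the child environment.
        env = dict(os.environ, RANK=rank)
        subprocess.run([sys.executable, script_path], check=True, env=env)

        # Step 4: if the output file exists, deserialize and return it.
        if os.path.exists(output_path):
            with open(output_path, "rb") as f:
                return pickler.load(f)
        return None


if __name__ == "__main__":
    print(run_training_function(operator.add, 2, 3))
```

Because the script tests `if output`, a falsy result such as `0` or an empty list is never written, so the caller cannot tell it apart from "no output". A process whose `RANK` is not `"0"` likewise produces no file, and the helper returns `None`.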