[ https://issues.apache.org/jira/browse/SPARK-41775?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Rithwik Ediga Lakhamsani updated SPARK-41775:
---------------------------------------------
    Component/s: ML

> Implement training functions as input
> -------------------------------------
>
>                 Key: SPARK-41775
>                 URL: https://issues.apache.org/jira/browse/SPARK-41775
>             Project: Spark
>          Issue Type: Sub-task
>          Components: ML, PySpark
>    Affects Versions: 3.4.0
>            Reporter: Rithwik Ediga Lakhamsani
>            Priority: Major
>
> Currently, `Distributor().run(...)` accepts only files as input. This sub-task
> adds support for passing a training function (with its arguments) as well.
> Each task on the executor nodes will then:
> 1. Pickle the input function and its args into `train_input.pkl`.
> 2. Create a temporary train.py file that looks like:
> ```python
> import os
>
> import cloudpickle
>
> if __name__ == "__main__":
>     # {tempdir} is the task's temp directory, substituted when train.py is generated
>     with open(f"{tempdir}/train_input.pkl", "rb") as f:
>         train, args = cloudpickle.load(f)
>     output = train(*args)
>     # RANK is set by torchrun; rank 0 runs on partitionId == 0
>     if output and os.environ.get("RANK", "") == "0":
>         with open(f"{tempdir}/train_output.pkl", "wb") as f:
>             cloudpickle.dump(output, f)
> ```
> 3. Run that train.py file with `torchrun`.
> 4. Check whether `train_output.pkl` has been created by the process on
> partitionId == 0; if it has, deserialize it and return that output.
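Steps 2 to 4 of the process described above can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the Spark implementation: `TRAIN_TEMPLATE`, `write_train_script`, `torchrun_command` and `read_output` are hypothetical names, step 1 (pickling `train` and `args` with `cloudpickle`) is left out so nothing outside the standard library is needed, the temp directory is assumed to be visible to every process the task launches, and the launch uses `python -m torch.distributed.run`, which PyTorch documents as equivalent to `torchrun`.

```python
import os
import pickle
import sys
import tempfile
import textwrap
from typing import Any, List, Optional

# Hypothetical template for the generated train.py (step 2).
# str.format replaces {tempdir!r} with a quoted Python string literal.
TRAIN_TEMPLATE = textwrap.dedent(
    """\
    import os

    import cloudpickle

    if __name__ == "__main__":
        with open(os.path.join({tempdir!r}, "train_input.pkl"), "rb") as f:
            train, args = cloudpickle.load(f)
        output = train(*args)
        # torchrun sets RANK; only rank 0 writes the result
        if output and os.environ.get("RANK", "") == "0":
            with open(os.path.join({tempdir!r}, "train_output.pkl"), "wb") as f:
                cloudpickle.dump(output, f)
    """
)


def write_train_script(tempdir: str) -> str:
    """Step 2 (hypothetical helper): write train.py into tempdir, return its path."""
    path = os.path.join(tempdir, "train.py")
    with open(path, "w") as f:
        f.write(TRAIN_TEMPLATE.format(tempdir=tempdir))
    return path


def torchrun_command(script_path: str, nproc_per_node: int) -> List[str]:
    """Step 3 (hypothetical helper): build the launch command.

    `torchrun` is an entry point for `python -m torch.distributed.run`.
    """
    return [
        sys.executable,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={nproc_per_node}",
        script_path,
    ]


def read_output(tempdir: str) -> Optional[Any]:
    """Step 4 (hypothetical helper): return rank 0's result if it was written."""
    path = os.path.join(tempdir, "train_output.pkl")
    if not os.path.exists(path):
        return None
    # cloudpickle.load is an alias of pickle.load, so the stdlib reader works
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tempdir:
        script = write_train_script(tempdir)
        # subprocess.run(cmd, check=True) would launch it where torch is installed
        print(" ".join(torchrun_command(script, nproc_per_node=2)))
        print(read_output(tempdir))  # nothing has run yet, so no output file
```

Writing the path with `{tempdir!r}` rather than a bare `{tempdir}` keeps the generated script valid even when the path contains quotes or backslashes. Reading the result with `pickle.load` only changes which function opens the stream; objects pickled by value (for example locally defined functions) still need `cloudpickle` importable to be reconstructed.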