rithwik-db commented on code in PR #40724:
URL: https://github.com/apache/spark/pull/40724#discussion_r1162233394


##########
python/pyspark/ml/torch/distributor.py:
##########
@@ -744,7 +814,99 @@ def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
                 TorchDistributor._run_training_on_pytorch_function  # type: ignore
             )
         if self.local_mode:
-            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
         else:
-            output = self._run_distributed_training(framework_wrapper_fn, train_object, *args)
+            output = self._run_distributed_training(
+                framework_wrapper_fn, train_object, None, *args, **kwargs
+            )
         return output
+
+    def train_on_dataframe(self, train_function, spark_dataframe, *args, **kwargs):

Review Comment:
   It makes sense to create a new API since we are including `spark_dataframe` here, but the original premise was to have a single `.run` function that handles all forms of input. I guess I'm not opposed to it, but can you add an example [here](https://github.com/apache/spark/blob/06e3b53ef432bdb4bab23c9204f4cbe096803f57/python/pyspark/ml/torch/distributor.py#L268) to show what it would look like?
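   For discussion, here is a toy sketch of the two entry points sharing one code path, mirroring how the diff passes `None` as the dataset slot from `run` and forwards `*args, **kwargs`. It is stdlib-only: `ToyDistributor`, `_run_training`, and `fit` are hypothetical names, not the real `TorchDistributor`, and handing the data to the train function as its first argument is an assumption made for illustration (the PR may deliver DataFrame data differently).

```python
from typing import Any, Callable, Optional


class ToyDistributor:
    """Hypothetical stand-in for TorchDistributor; no Spark or PyTorch involved."""

    def _run_training(
        self,
        train_object: Callable[..., Any],
        dataset: Optional[Any],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        # Shared code path. `dataset` is None when called from run(), mirroring
        # the None passed to _run_distributed_training in the diff.
        # Assumption: the data is passed as the first positional argument.
        if dataset is not None:
            return train_object(dataset, *args, **kwargs)
        return train_object(*args, **kwargs)

    def run(self, train_object: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        # Existing entry point: positional and keyword arguments pass through unchanged.
        return self._run_training(train_object, None, *args, **kwargs)

    def train_on_dataframe(
        self,
        train_function: Callable[..., Any],
        spark_dataframe: Any,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        # Proposed entry point: same path, with the data supplied explicitly.
        return self._run_training(train_function, spark_dataframe, *args, **kwargs)


def fit(data: Any, lr: float, epochs: int = 1) -> dict:
    # Toy "training function" (hypothetical): reports what it received.
    return {"rows": len(data), "lr": lr, "epochs": epochs}


distributor = ToyDistributor()
print(distributor.run(lambda x, scale=1: x * scale, 2, scale=5))  # 10
print(distributor.train_on_dataframe(fit, [1, 2, 3], 0.01, epochs=2))
# {'rows': 3, 'lr': 0.01, 'epochs': 2}
```

   The trade-off raised above shows up directly: a single `run` would need to tell a dataset argument apart from ordinary `*args`, whereas a separate `train_on_dataframe` makes the data slot explicit at the cost of a second public method.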



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

