This is an automated email from the ASF dual-hosted git repository.

byronhsu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f79fc2  SUBMARINE-868. Add save model to pysubmarine ModelsClient
4f79fc2 is described below

commit 4f79fc2846e66d90a723fc49645acfaf3439f2a3
Author: jeff-901 <[email protected]>
AuthorDate: Sat Jul 3 13:50:35 2021 +0800

    SUBMARINE-868. Add save model to pysubmarine ModelsClient
    
    ### What is this PR for?
    Add save model function to submarine-sdk model management.
    
    ### What type of PR is it?
    Feature
    
    ### Todos
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-868
    
    ### How should this be tested?
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need new documentation? Yes
    
    Author: jeff-901 <[email protected]>
    
    Signed-off-by: byronhsu <[email protected]>
    
    Closes #634 from jeff-901/SUBMARINE-868 and squashes the following commits:
    
    cf018633 [jeff-901] edit function name
    50655ab4 [jeff-901] add save_model
---
 .../pysubmarine/submarine/models/client.py         | 31 +++++++++++++++++-----
 .../pysubmarine/submarine/models/utils.py          | 10 +++++++
 .../pysubmarine/tests/models/test_model.py         |  9 +++----
 .../pysubmarine/tests/models/test_model_e2e.py     |  5 +++-
 4 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index 0513ad3..084778f 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -22,7 +22,7 @@ from mlflow.tracking import MlflowClient
 
 from .constant import (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
                        MLFLOW_S3_ENDPOINT_URL, MLFLOW_TRACKING_URI)
-from .utils import get_job_id, get_worker_index
+from .utils import exist_ps, get_job_id, get_worker_index
 
 
 class ModelsClient():
@@ -39,6 +39,12 @@ class ModelsClient():
         os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET_ACCESS_KEY
         os.environ["MLFLOW_TRACKING_URI"] = tracking_uri or MLFLOW_TRACKING_URI
         self.client = MlflowClient()
+        self.type_to_log_model = {
+            "pytorch": mlflow.pytorch.log_model,
+            "sklearn": mlflow.sklearn.log_model,
+            "tensorflow": mlflow.tensorflow.log_model,
+            "keras": mlflow.keras.log_model
+        }
 
     def start(self):
         """
@@ -66,11 +72,6 @@ class ModelsClient():
     def log_metrics(self, metrics, step=None):
         mlflow.log_metrics(metrics, step)
 
-    def log_model(self, name, checkpoint):
-        mlflow.pytorch.log_model(registered_model_name=name,
-                                 pytorch_model=checkpoint,
-                                 artifact_path="pytorch-model")
-
     def load_model(self, name, version):
         model = mlflow.pyfunc.load_model(model_uri=f"models:/{name}/{version}")
         return model
@@ -81,6 +82,24 @@ class ModelsClient():
     def delete_model(self, name, version):
         self.client.delete_model_version(name=name, version=version)
 
+    def save_model(self,
+                   model_type,
+                   model,
+                   artifact_path,
+                   registered_model_name=None):
+        run_name = get_worker_index()
+        if exist_ps():
+            # TODO for Tensorflow ParameterServer strategy
+            return
+        elif run_name == "worker-0":
+            if model_type in self.type_to_log_model:
+                self.type_to_log_model[model_type](
+                    model,
+                    artifact_path,
+                    registered_model_name=registered_model_name)
+            else:
+                raise MlflowException("No valid type of model has been matched")
+
     def _get_or_create_experiment(self, experiment_name):
         """
         Return the id of experiment.
diff --git a/submarine-sdk/pysubmarine/submarine/models/utils.py b/submarine-sdk/pysubmarine/submarine/models/utils.py
index 7bb5d5d..20bb277 100644
--- a/submarine-sdk/pysubmarine/submarine/models/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/models/utils.py
@@ -25,6 +25,7 @@ _JOB_ID_ENV_VAR = "JOB_ID"
 
 _TF_CONFIG = "TF_CONFIG"
 _CLUSTER_SPEC = "CLUSTER_SPEC"
+_CLUSTER = "cluster"
 _JOB_NAME = "JOB_NAME"
 _TYPE = "type"
 _TASK = "task"
@@ -76,3 +77,12 @@ def get_worker_index():
         worker_index = "worker-0"
 
     return worker_index
+
+
+def exist_ps():
+    if env.get_env(_TF_CONFIG) is not None:
+        tf_config = json.loads(os.environ.get(_TF_CONFIG))
+        cluster = tf_config.get(_CLUSTER)
+        if "ps" in cluster:
+            return True
+    return False
diff --git a/submarine-sdk/pysubmarine/tests/models/test_model.py b/submarine-sdk/pysubmarine/tests/models/test_model.py
index 6d4cb8e..e0efcb7 100644
--- a/submarine-sdk/pysubmarine/tests/models/test_model.py
+++ b/submarine-sdk/pysubmarine/tests/models/test_model.py
@@ -32,14 +32,13 @@ class TestSubmarineModelsClient():
     def tearDown(self):
         pass
 
-    @pytest.mark.skip(reason="Developing")
-    def test_log_model(self, mocker):
-        mock_method = mocker.patch.object(ModelsClient, "log_model")
+    def test_save_model(self, mocker):
+        mock_method = mocker.patch.object(ModelsClient, "save_model")
         client = ModelsClient()
         model = LinearNNModel()
         name = "simple-nn-model"
-        client.log_model(name, model)
-        mock_method.assert_called_once_with("simple-nn-model", model)
+        client.save_model("pytorch", model, name)
+        mock_method.assert_called_once_with("pytorch", model, "simple-nn-model")
 
     def test_update_model(self, mocker):
         mock_method = mocker.patch.object(MlflowClient,
diff --git a/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
index 63e1d6d..7fb1a55 100644
--- a/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
+++ b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
@@ -38,7 +38,10 @@ class TestSubmarineModelsClientE2E():
         model = LinearNNModel()
         # log
         name = "simple-nn-model"
-        models_client.log_model(name, model)
+        models_client.save_model("pytorch",
+                                 model,
+                                 name,
+                                 registered_model_name=name)
         # update
         new_name = "new-simple-nn-model"
         models_client.update_model(name, new_name)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to