This is an automated email from the ASF dual-hosted git repository.
byronhsu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 4f79fc2 SUBMARINE-868. Add save model to pysubmarine ModelsClient
4f79fc2 is described below
commit 4f79fc2846e66d90a723fc49645acfaf3439f2a3
Author: jeff-901 <[email protected]>
AuthorDate: Sat Jul 3 13:50:35 2021 +0800
SUBMARINE-868. Add save model to pysubmarine ModelsClient
### What is this PR for?
Add a `save_model` function to the submarine-sdk model management client (`ModelsClient`).
### What type of PR is it?
Feature
### Todos
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-868
### How should this be tested?
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? Yes
Author: jeff-901 <[email protected]>
Signed-off-by: byronhsu <[email protected]>
Closes #634 from jeff-901/SUBMARINE-868 and squashes the following commits:
cf018633 [jeff-901] edit function name
50655ab4 [jeff-901] add save_model
---
.../pysubmarine/submarine/models/client.py | 31 +++++++++++++++++-----
.../pysubmarine/submarine/models/utils.py | 10 +++++++
.../pysubmarine/tests/models/test_model.py | 9 +++----
.../pysubmarine/tests/models/test_model_e2e.py | 5 +++-
4 files changed, 43 insertions(+), 12 deletions(-)
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py
b/submarine-sdk/pysubmarine/submarine/models/client.py
index 0513ad3..084778f 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -22,7 +22,7 @@ from mlflow.tracking import MlflowClient
from .constant import (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
MLFLOW_S3_ENDPOINT_URL, MLFLOW_TRACKING_URI)
-from .utils import get_job_id, get_worker_index
+from .utils import exist_ps, get_job_id, get_worker_index
class ModelsClient():
@@ -39,6 +39,12 @@ class ModelsClient():
os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET_ACCESS_KEY
os.environ["MLFLOW_TRACKING_URI"] = tracking_uri or MLFLOW_TRACKING_URI
self.client = MlflowClient()
+ self.type_to_log_model = {
+ "pytorch": mlflow.pytorch.log_model,
+ "sklearn": mlflow.sklearn.log_model,
+ "tensorflow": mlflow.tensorflow.log_model,
+ "keras": mlflow.keras.log_model
+ }
def start(self):
"""
@@ -66,11 +72,6 @@ class ModelsClient():
def log_metrics(self, metrics, step=None):
mlflow.log_metrics(metrics, step)
- def log_model(self, name, checkpoint):
- mlflow.pytorch.log_model(registered_model_name=name,
- pytorch_model=checkpoint,
- artifact_path="pytorch-model")
-
def load_model(self, name, version):
model = mlflow.pyfunc.load_model(model_uri=f"models:/{name}/{version}")
return model
@@ -81,6 +82,24 @@ class ModelsClient():
def delete_model(self, name, version):
self.client.delete_model_version(name=name, version=version)
+ def save_model(self,
+ model_type,
+ model,
+ artifact_path,
+ registered_model_name=None):
+ run_name = get_worker_index()
+ if exist_ps():
+ # TODO for Tensorflow ParameterServer strategy
+ return
+ elif run_name == "worker-0":
+ if model_type in self.type_to_log_model:
+ self.type_to_log_model[model_type](
+ model,
+ artifact_path,
+ registered_model_name=registered_model_name)
+ else:
+ raise MlflowException("No valid type of model has been
matched")
+
def _get_or_create_experiment(self, experiment_name):
"""
Return the id of experiment.
diff --git a/submarine-sdk/pysubmarine/submarine/models/utils.py
b/submarine-sdk/pysubmarine/submarine/models/utils.py
index 7bb5d5d..20bb277 100644
--- a/submarine-sdk/pysubmarine/submarine/models/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/models/utils.py
@@ -25,6 +25,7 @@ _JOB_ID_ENV_VAR = "JOB_ID"
_TF_CONFIG = "TF_CONFIG"
_CLUSTER_SPEC = "CLUSTER_SPEC"
+_CLUSTER = "cluster"
_JOB_NAME = "JOB_NAME"
_TYPE = "type"
_TASK = "task"
@@ -76,3 +77,12 @@ def get_worker_index():
worker_index = "worker-0"
return worker_index
+
+
+def exist_ps():
+ if env.get_env(_TF_CONFIG) is not None:
+ tf_config = json.loads(os.environ.get(_TF_CONFIG))
+ cluster = tf_config.get(_CLUSTER)
+ if "ps" in cluster:
+ return True
+ return False
diff --git a/submarine-sdk/pysubmarine/tests/models/test_model.py
b/submarine-sdk/pysubmarine/tests/models/test_model.py
index 6d4cb8e..e0efcb7 100644
--- a/submarine-sdk/pysubmarine/tests/models/test_model.py
+++ b/submarine-sdk/pysubmarine/tests/models/test_model.py
@@ -32,14 +32,13 @@ class TestSubmarineModelsClient():
def tearDown(self):
pass
- @pytest.mark.skip(reason="Developing")
- def test_log_model(self, mocker):
- mock_method = mocker.patch.object(ModelsClient, "log_model")
+ def test_save_model(self, mocker):
+ mock_method = mocker.patch.object(ModelsClient, "save_model")
client = ModelsClient()
model = LinearNNModel()
name = "simple-nn-model"
- client.log_model(name, model)
- mock_method.assert_called_once_with("simple-nn-model", model)
+ client.save_model("pytorch", model, name)
+ mock_method.assert_called_once_with("pytorch", model,
"simple-nn-model")
def test_update_model(self, mocker):
mock_method = mocker.patch.object(MlflowClient,
diff --git a/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
index 63e1d6d..7fb1a55 100644
--- a/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
+++ b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
@@ -38,7 +38,10 @@ class TestSubmarineModelsClientE2E():
model = LinearNNModel()
# log
name = "simple-nn-model"
- models_client.log_model(name, model)
+ models_client.save_model("pytorch",
+ model,
+ name,
+ registered_model_name=name)
# update
new_name = "new-simple-nn-model"
models_client.update_model(name, new_name)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]