KUAN-HSUN-LI commented on a change in pull request #772:
URL: https://github.com/apache/submarine/pull/772#discussion_r725605761



##########
File path: submarine-sdk/pysubmarine/submarine/tracking/utils.py
##########
@@ -88,5 +87,13 @@ def get_worker_index():
     return worker_index
 
 
-def get_sqlalchemy_store(store_uri):
+def get_sqlalchemy_store(store_uri: str):

Review comment:
       Would `get_tracking_sqlalchemy_store` be a better name?

##########
File path: submarine-sdk/pysubmarine/submarine/tracking/client.py
##########
@@ -64,3 +90,38 @@ def log_param(self, job_id, key, value, worker_index):
         validate_param(key, value)
         param = Param(key, str(value), worker_index)
         self.store.log_param(job_id, param)
+
+    def save_model(
+        self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+    ) -> None:

Review comment:
       Add a comment describing what happens when `registered_model_name` is 
None and when it is not None.
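
One way the requested comment could read is as a docstring. This is a minimal sketch, not the PR's code: `InMemoryRegistry`, the free-function form of `save_model`, and the `s3://bucket/...` placeholder are assumptions made so the example runs on its own. The documented behavior mirrors the `if registered_model_name is not None` branch shown in the diff.

```python
from typing import Any, Dict, List, Optional


class InMemoryRegistry:
    """Hypothetical stand-in for the model registry, for illustration only."""

    def __init__(self) -> None:
        self.versions: Dict[str, List[str]] = {}

    def register(self, name: str, source: str) -> int:
        # Create the registered model on first use, then append a new version.
        self.versions.setdefault(name, []).append(source)
        return len(self.versions[name])


def save_model(
    registry: InMemoryRegistry,
    model: Any,
    artifact_path: str,
    registered_model_name: Optional[str] = None,
) -> str:
    """Save a model and optionally register it.

    :param registered_model_name: If None, the model is only saved as an
        artifact and nothing is registered. If not None, a new model version
        is created under that name; the registered model itself is created
        first when it does not exist yet.
    """
    # Placeholder for the real upload step (assumed bucket name).
    source = f"s3://bucket/{artifact_path}"
    if registered_model_name is not None:
        registry.register(registered_model_name, source)
    return source
```

Calling it twice with the same `registered_model_name` yields two versions under one name, which is what the test in this PR checks.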

##########
File path: submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
##########
@@ -73,3 +79,25 @@ def test_log_metric(self):
             assert metrics[0].value == 5
             assert metrics[0].id == JOB_ID
             assert metrics[1].value == 6
+
+    @pytest.mark.skip(reason="using tensorflow 2")

Review comment:
       Should this be `@pytest.mark.skipif(tf.version.VERSION < '2.0', 
reason="using tensorflow 2")`? (`skip` is unconditional; `skipif` takes the 
condition.)
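
A conditional skip needs `pytest.mark.skipif`, since `pytest.mark.skip` is unconditional. Comparing version strings directly is also fragile, because strings compare character by character. The helper below is a hypothetical sketch (the name `major_version` is not from the PR) that compares the major version as an integer instead.

```python
import re


def major_version(version: str) -> int:
    """Return the leading integer of a version string such as '2.6.0-rc1'."""
    match = re.match(r"\d+", version)
    if match is None:
        raise ValueError(f"not a version string: {version!r}")
    return int(match.group())


# String comparison goes wrong once the major version has two digits:
string_compare_says_old = "10.0" < "2.0"      # True, which is misleading
numeric_compare_says_old = major_version("10.0") < 2  # False, as intended

# Assumed usage in the test module (requires pytest and tensorflow):
#
#   @pytest.mark.skipif(
#       major_version(tf.version.VERSION) < 2, reason="using tensorflow 2"
#   )
#   def test_save_model(self): ...
```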

##########
File path: submarine-sdk/pysubmarine/submarine/store/__init__.py
##########
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine"
+
+DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@submarine-database:3306/submarine"

Review comment:
       Can you also replace `localhost` with `submarine-database` in the other 
files, and remove the port forwarding from the integration test in GitHub 
Actions?

##########
File path: submarine-sdk/pysubmarine/submarine/artifacts/repository.py
##########
@@ -71,3 +71,4 @@ def log_artifacts(self, local_dir, artifact_path):
                     bucket=bucket,
                     key=os.path.join(upload_path, f),
                 )
+        return f"s3://{bucket}/{dest_path}"

Review comment:
       Also, add type annotations in this file, including the return type.
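
As an illustration of the annotations being asked for, here is a simplified sketch. `ArtifactRepositorySketch` is hypothetical: it records the keys it would upload instead of calling S3, and it assumes the returned path is the `artifact_path` itself, since the diff does not show how `dest_path` is derived.

```python
import os
import posixpath
from typing import List


class ArtifactRepositorySketch:
    """Hypothetical, simplified repository that records keys instead of uploading."""

    def __init__(self, bucket: str) -> None:
        self.bucket = bucket
        self.uploaded_keys: List[str] = []

    def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
        """Record one key per file under local_dir and return the S3 URI."""
        for root, _, files in os.walk(local_dir):
            rel = os.path.relpath(root, local_dir)
            # Keep object keys POSIX-style regardless of the local OS.
            upload_path = (
                artifact_path
                if rel == "."
                else posixpath.join(artifact_path, *rel.split(os.sep))
            )
            for f in files:
                self.uploaded_keys.append(posixpath.join(upload_path, f))
        return f"s3://{self.bucket}/{artifact_path}"
```

With `-> str` on the method, callers such as `save_model` can rely on the returned URI as the `source` of a model version.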

##########
File path: submarine-sdk/pysubmarine/submarine/tracking/fluent.py
##########
@@ -52,3 +52,7 @@ def log_metric(key, value, step=None):
     job_id = get_job_id()
     worker_index = get_worker_index()
     SubmarineClient().log_metric(job_id, key, value, worker_index, datetime.now(), step or 0)
+
+
+def save_model(self, model_type: str, model, artifact_path: str, registered_model_name: str = None):
+    SubmarineClient().save_model(model_type, model, artifact_path, registered_model_name)

Review comment:
       Add the same comment as above.

##########
File path: submarine-sdk/pysubmarine/submarine/tracking/client.py
##########
@@ -64,3 +90,38 @@ def log_param(self, job_id, key, value, worker_index):
         validate_param(key, value)
         param = Param(key, str(value), worker_index)
         self.store.log_param(job_id, param)
+
+    def save_model(
+        self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+    ) -> None:
+        pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
+        if not re.fullmatch(pattern, artifact_path):
+            raise Exception(
+                "Artifact_path must only contains numbers, characters, hyphen and underscore.      "
+                "        Artifact_path must starts and ends with numbers or characters."
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            if model_type == "pytorch":
+                import submarine.models.pytorch
+
+                submarine.models.pytorch.save_model(model, tempdir)
+            elif model_type == "tensorflow":
+                import submarine.models.tensorflow
+
+                submarine.models.tensorflow.save_model(model, tempdir)
+            else:
+                raise Exception("No valid type of model has been matched to {}".format(model_type))
+            source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
+
+        # Register model
+        if registered_model_name is not None:
+            try:
+                self.model_registry.get_registered_model(registered_model_name)
+            except SubmarineException:
+                self.model_registry.create_registered_model(name=registered_model_name)
+            self.model_registry.create_model_version(
+                name=registered_model_name,
+                source=source,
+                user_id="TODO",

Review comment:
       Add a TODO comment
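
For reference, the `artifact_path` check at the top of this hunk can be exercised in isolation. The sketch below is an assumption-labeled rewrite, not the PR's code: `is_valid_artifact_path` is a hypothetical helper, and the character class is written as `[0-9A-Za-z_-]`, which should match the same characters as the `[0-9A-Za-z-_]` in the diff while keeping the hyphen unambiguous.

```python
import re

# Either a single alphanumeric character, or a run that starts and ends
# with an alphanumeric character and has only alphanumerics, hyphens,
# and underscores in between.
ARTIFACT_PATH_PATTERN = re.compile(
    r"[0-9A-Za-z][0-9A-Za-z_-]*[0-9A-Za-z]|[0-9A-Za-z]"
)


def is_valid_artifact_path(artifact_path: str) -> bool:
    """Return True when the whole string satisfies the pattern."""
    return ARTIFACT_PATH_PATTERN.fullmatch(artifact_path) is not None
```

Names such as `name_1` pass, while a leading hyphen, a trailing underscore, or a slash is rejected.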

##########
File path: submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
##########
@@ -73,3 +79,25 @@ def test_log_metric(self):
             assert metrics[0].value == 5
             assert metrics[0].id == JOB_ID
             assert metrics[1].value == 6
+
+    @pytest.mark.skip(reason="using tensorflow 2")
+    def test_save_model(self):
+        model = LinearNNModel()
+        registered_model_name = "registerd_model_name"
+        submarine.save_model("tensorflow", model, "name_1", registered_model_name)
+        submarine.save_model("tensorflow", model, "name_2", registered_model_name)
+        # Validate model_versions
+        with self.model_registry.ManagedSessionMaker() as session:
+            model_versions = (
+                session.query(SqlModelVersion)
+                .options()
+                .filter(SqlModelVersion.name == registered_model_name)
+                .all()
+            )

Review comment:
       1. Could this be simplified to `model_versions = 
self.model_registry.list_model_versions(registered_model_name)`?
       2. Can you also delete all the models in the S3 bucket when the test 
tears down?




