This is an automated email from the ASF dual-hosted git repository. pingsutw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push: new 535770e SUBMARINE-1005. Register model version when saving model. 535770e is described below commit 535770eaa1935d04d5bea13d326406e7c9286d77 Author: jeff-901 <b07901...@ntu.edu.tw> AuthorDate: Sun Oct 17 11:38:51 2021 +0800 SUBMARINE-1005. Register model version when saving model. ### What is this PR for? Migrate save model function to SubmarineClient. Implement logic of register model in save model function. ### What type of PR is it? Feature ### Todos ### What is the Jira issue? https://issues.apache.org/jira/browse/SUBMARINE-1005 ### How should this be tested? github action. ### Screenshots (if appropriate) ### Questions: * Do the license files need updating? No * Are there breaking changes for older versions? No * Does this need new documentation? No Author: jeff-901 <b07901...@ntu.edu.tw> Signed-off-by: Kevin <pings...@apache.org> Closes #772 from jeff-901/SUBMARINE-1005 and squashes the following commits: 3371b5fe [jeff-901] fix duplicate 7a6cbe91 [jeff-901] fix typo 091ae50f [jeff-901] edit document and fix test d16c93a1 [jeff-901] fix bugs 0f4a9069 [jeff-901] fix model client c629f25b [jeff-901] add test and remove duplicate code b70ceedb [jeff-901] add mypy syntax 03f467e7 [jeff-901] checkstyle e4e13834 [jeff-901] refactor alchemy_store 36631d2c [jeff-901] add save model in submarine client --- submarine-sdk/pysubmarine/submarine/__init__.py | 2 + .../pysubmarine/submarine/artifacts/repository.py | 12 +++- .../pysubmarine/submarine/models/client.py | 26 -------- .../pysubmarine/submarine/store/__init__.py | 3 +- .../submarine/store/{ => tracking}/__init__.py | 4 -- .../store/{ => tracking}/abstract_store.py | 0 .../store/{ => tracking}/sqlalchemy_store.py | 7 +- .../pysubmarine/submarine/tracking/client.py | 77 ++++++++++++++++++++-- .../pysubmarine/submarine/tracking/constant.py | 20 ++++++ .../pysubmarine/submarine/tracking/fluent.py | 12 ++++ 
.../pysubmarine/submarine/tracking/utils.py | 11 +++- .../tests/store/tracking/test_sqlalchemy_store.py | 2 +- .../pysubmarine/tests/tracking/test_tracking.py | 38 ++++++++++- .../pysubmarine/tests/tracking/test_utils.py | 10 +-- .../pysubmarine/tests/tracking/tf_model.py | 27 ++++++++ website/docs/userDocs/submarine-sdk/tracking.md | 2 +- .../userDocs/submarine-sdk/tracking.md | 2 +- 17 files changed, 205 insertions(+), 50 deletions(-) diff --git a/submarine-sdk/pysubmarine/submarine/__init__.py b/submarine-sdk/pysubmarine/submarine/__init__.py index 85519e8..0554922 100644 --- a/submarine-sdk/pysubmarine/submarine/__init__.py +++ b/submarine-sdk/pysubmarine/submarine/__init__.py @@ -20,12 +20,14 @@ from submarine.models.client import ModelsClient log_param = submarine.tracking.fluent.log_param log_metric = submarine.tracking.fluent.log_metric +save_model = submarine.tracking.fluent.save_model set_db_uri = utils.set_db_uri get_db_uri = utils.get_db_uri __all__ = [ "log_metric", "log_param", + "save_model", "set_db_uri", "get_db_uri", "ExperimentClient", diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py index 3dff6fc..5a60648 100644 --- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py +++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py @@ -50,7 +50,7 @@ class Repository: key=dest_path, ) - def log_artifacts(self, local_dir, artifact_path): + def log_artifacts(self, local_dir: str, artifact_path: str) -> str: bucket = "submarine" dest_path = self.dest_path list_of_subfolder = self._list_artifact_subfolder(artifact_path) @@ -71,3 +71,13 @@ class Repository: bucket=bucket, key=os.path.join(upload_path, f), ) + return f"s3://{bucket}/{dest_path}" + + def delete_folder(self) -> None: + objects_to_delete = self.client.list_objects(Bucket="submarine", Prefix=self.dest_path) + if objects_to_delete.get("Contents") is not None: + delete_keys: dict = {"Objects": []} 
+ delete_keys["Objects"] = [ + {"Key": k} for k in [obj["Key"] for obj in objects_to_delete.get("Contents")] + ] + self.client.delete_objects(Bucket="submarine", Delete=delete_keys) diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py index 9cf655f..e633188 100644 --- a/submarine-sdk/pysubmarine/submarine/models/client.py +++ b/submarine-sdk/pysubmarine/submarine/models/client.py @@ -15,16 +15,12 @@ under the License. """ import os -import re -import tempfile import time import mlflow from mlflow.exceptions import MlflowException from mlflow.tracking import MlflowClient -from submarine.artifacts.repository import Repository - from .constant import ( AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, @@ -58,7 +54,6 @@ class ModelsClient: "tensorflow": mlflow.tensorflow.log_model, "keras": mlflow.keras.log_model, } - self.artifact_repo = Repository(get_job_id()) def start(self): """ @@ -109,27 +104,6 @@ class ModelsClient: else: raise MlflowException("No valid type of model has been matched") - def save_model_submarine(self, model_type, model, artifact_path, registered_model_name=None): - pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]" - if not re.fullmatch(pattern, artifact_path): - raise Exception( - "Artifact_path must only contains numbers, characters, hyphen and underscore. " - " Artifact_path must starts and ends with numbers or characters." 
- ) - with tempfile.TemporaryDirectory() as tempdir: - if model_type == "pytorch": - import submarine.models.pytorch - - submarine.models.pytorch.save_model(model, tempdir) - elif model_type == "tensorflow": - import submarine.models.tensorflow - - submarine.models.tensorflow.save_model(model, tempdir) - else: - raise Exception("No valid type of model has been matched to {}".format(model_type)) - self.artifact_repo.log_artifacts(tempdir, artifact_path) - # TODO for registering model () - def _get_or_create_experiment(self, experiment_name): """ Return the id of experiment. diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/__init__.py index 60412a9..458bc11 100644 --- a/submarine-sdk/pysubmarine/submarine/store/__init__.py +++ b/submarine-sdk/pysubmarine/submarine/store/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine" + +DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@submarine-database:3306/submarine" __all__ = ["DEFAULT_SUBMARINE_JDBC_URL"] diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py similarity index 85% copy from submarine-sdk/pysubmarine/submarine/store/__init__.py copy to submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py index 60412a9..a6eb1b5 100644 --- a/submarine-sdk/pysubmarine/submarine/store/__init__.py +++ b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py @@ -12,7 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine" - -__all__ = ["DEFAULT_SUBMARINE_JDBC_URL"] diff --git a/submarine-sdk/pysubmarine/submarine/store/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py similarity index 100% rename from submarine-sdk/pysubmarine/submarine/store/abstract_store.py rename to submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py diff --git a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py similarity index 96% rename from submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py rename to submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py index 01adec5..23e8a8d 100644 --- a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py +++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py @@ -19,9 +19,10 @@ from contextlib import contextmanager import sqlalchemy +from submarine.entities import Param from submarine.exceptions import SubmarineException -from submarine.store.abstract_store import AbstractStore from submarine.store.database.models import Base, SqlMetric, SqlParam +from submarine.store.tracking.abstract_store import AbstractStore from submarine.utils import extract_db_type_from_uri _logger = logging.getLogger(__name__) @@ -42,7 +43,7 @@ class SqlAlchemyStore(AbstractStore): :py:class:`submarine.store.database.models.SqlParam`. """ - def __init__(self, db_uri): + def __init__(self, db_uri: str) -> None: """ Create a database backed store. :param db_uri: The SQLAlchemy database URI string to connect to the database. 
See @@ -151,7 +152,7 @@ class SqlAlchemyStore(AbstractStore): except sqlalchemy.exc.IntegrityError: session.rollback() - def log_param(self, job_id, param): + def log_param(self, job_id: str, param: Param) -> None: with self.ManagedSessionMaker() as session: try: self._get_or_create( diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py index 2ee9b09..982dfe8 100644 --- a/submarine-sdk/pysubmarine/submarine/tracking/client.py +++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py @@ -12,30 +12,56 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import re +import tempfile import time import submarine +from submarine.artifacts.repository import Repository from submarine.entities import Metric, Param +from submarine.exceptions import SubmarineException from submarine.tracking import utils from submarine.utils.validation import validate_metric, validate_param +from .constant import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_ENDPOINT_URL + class SubmarineClient(object): """ Client of an submarine Tracking Server that creates and manages experiments and runs. """ - def __init__(self, db_uri=None): + def __init__( + self, + db_uri: str = None, + s3_registry_uri: str = None, + aws_access_key_id: str = None, + aws_secret_access_key: str = None, + ) -> None: """ :param db_uri: Address of local or remote tracking server. If not provided, defaults to the service set by ``submarine.tracking.set_db_uri``. See `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_ for more info. 
""" + os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_registry_uri or S3_ENDPOINT_URL + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id or AWS_ACCESS_KEY_ID + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key or AWS_SECRET_ACCESS_KEY + self.artifact_repo = Repository(utils.get_job_id()) self.db_uri = db_uri or submarine.get_db_uri() - self.store = utils.get_sqlalchemy_store(self.db_uri) + self.store = utils.get_tracking_sqlalchemy_store(self.db_uri) + self.model_registry = utils.get_model_registry_sqlalchemy_store(self.db_uri) - def log_metric(self, job_id, key, value, worker_index, timestamp=None, step=None): + def log_metric( + self, + job_id: str, + key: str, + value: float, + worker_index: str, + timestamp: int = None, + step: int = None, + ) -> None: """ Log a metric against the run ID. :param job_id: The job name to which the metric should be logged. @@ -53,7 +79,7 @@ class SubmarineClient(object): metric = Metric(key, value, worker_index, timestamp, step) self.store.log_metric(job_id, metric) - def log_param(self, job_id, key, value, worker_index): + def log_param(self, job_id: str, key: str, value: str, worker_index: str) -> None: """ Log a parameter against the job name. Value is converted to a string. :param job_id: The job name to which the parameter should be logged. @@ -64,3 +90,46 @@ class SubmarineClient(object): validate_param(key, value) param = Param(key, str(value), worker_index) self.store.log_param(job_id, param) + + def save_model( + self, model_type: str, model, artifact_path: str, registered_model_name: str = None + ) -> None: + """ + Save a model into the minio pod. + :param model_type: The type of the model. + :param model: Model. + :param artifact_path: Relative path of the artifact in the minio pod. + :param registered_model_name: If not None, register model into the model registry with + this name. If None, the model only be saved in minio pod. 
+ """ + pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]" + if not re.fullmatch(pattern, artifact_path): + raise Exception( + "Artifact_path must only contains numbers, characters, hyphen and underscore. " + " Artifact_path must starts and ends with numbers or characters." + ) + with tempfile.TemporaryDirectory() as tempdir: + if model_type == "pytorch": + import submarine.models.pytorch + + submarine.models.pytorch.save_model(model, tempdir) + elif model_type == "tensorflow": + import submarine.models.tensorflow + + submarine.models.tensorflow.save_model(model, tempdir) + else: + raise Exception("No valid type of model has been matched to {}".format(model_type)) + source = self.artifact_repo.log_artifacts(tempdir, artifact_path) + + # Register model + if registered_model_name is not None: + try: + self.model_registry.get_registered_model(registered_model_name) + except SubmarineException: + self.model_registry.create_registered_model(name=registered_model_name) + self.model_registry.create_model_version( + name=registered_model_name, + source=source, + user_id="", # TODO(jeff-901): the user id is needed to be specified. + experiment_id=utils.get_job_id(), + ) diff --git a/submarine-sdk/pysubmarine/submarine/tracking/constant.py b/submarine-sdk/pysubmarine/submarine/tracking/constant.py new file mode 100644 index 0000000..201d89a --- /dev/null +++ b/submarine-sdk/pysubmarine/submarine/tracking/constant.py @@ -0,0 +1,20 @@ +""" + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +""" + +S3_ENDPOINT_URL = "http://submarine-minio-service:9000" +AWS_ACCESS_KEY_ID = "submarine_minio" +AWS_SECRET_ACCESS_KEY = "submarine_minio" diff --git a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py index 8406ce7..aabe7ed 100644 --- a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py +++ b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py @@ -52,3 +52,15 @@ def log_metric(key, value, step=None): job_id = get_job_id() worker_index = get_worker_index() SubmarineClient().log_metric(job_id, key, value, worker_index, datetime.now(), step or 0) + + +def save_model(model_type: str, model, artifact_path: str, registered_model_name: str = None): + """ + Save a model into the minio pod. + :param model_type: The type of the model. + :param model: Model. + :param artifact_path: Relative path of the artifact in the minio pod. + :param registered_model_name: If not None, register model into the model registry with + this name. If None, the model is only saved in the minio pod.
+ """ + SubmarineClient().save_model(model_type, model, artifact_path, registered_model_name) diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py b/submarine-sdk/pysubmarine/submarine/tracking/utils.py index ec0ec14..4a223e1 100644 --- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py +++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py @@ -19,7 +19,6 @@ import json import os import uuid -from submarine.store.sqlalchemy_store import SqlAlchemyStore from submarine.utils import env _TRACKING_URI_ENV_VAR = "SUBMARINE_TRACKING_URI" @@ -88,5 +87,13 @@ def get_worker_index(): return worker_index -def get_sqlalchemy_store(store_uri): +def get_tracking_sqlalchemy_store(store_uri: str): + from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore + + return SqlAlchemyStore(store_uri) + + +def get_model_registry_sqlalchemy_store(store_uri: str): + from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore + return SqlAlchemyStore(store_uri) diff --git a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py index dbdacdc..3104800 100644 --- a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py +++ b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py @@ -22,7 +22,7 @@ import submarine from submarine.entities import Metric, Param from submarine.store.database import models from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam -from submarine.store.sqlalchemy_store import SqlAlchemyStore +from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore JOB_ID = "application_123456789" diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py index 7410e16..59feefa 100644 --- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py +++ 
b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py @@ -18,13 +18,18 @@ from datetime import datetime from os import environ import pytest +import tensorflow import submarine +from submarine.artifacts.repository import Repository from submarine.store.database import models from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam -from submarine.store.sqlalchemy_store import SqlAlchemyStore +from submarine.tracking.client import SubmarineClient + +from .tf_model import LinearNNModel JOB_ID = "application_123456789" +MLFLOW_S3_ENDPOINT_URL = "http://localhost:9000" @pytest.mark.e2e @@ -35,7 +40,16 @@ class TestTracking(unittest.TestCase): "mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test" ) self.db_uri = submarine.get_db_uri() + self.client = SubmarineClient( + db_uri=self.db_uri, + s3_registry_uri=MLFLOW_S3_ENDPOINT_URL, + ) + from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore + self.store = SqlAlchemyStore(self.db_uri) + from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore + + self.model_registry = SqlAlchemyStore(self.db_uri) # TODO: use submarine.tracking.fluent to support experiment create with self.store.ManagedSessionMaker() as session: instance = SqlExperiment( @@ -52,6 +66,10 @@ class TestTracking(unittest.TestCase): def tearDown(self): submarine.set_db_uri(None) models.Base.metadata.drop_all(self.store.engine) + environ["MLFLOW_S3_ENDPOINT_URL"] = MLFLOW_S3_ENDPOINT_URL + environ["AWS_ACCESS_KEY_ID"] = "submarine_minio" + environ["AWS_SECRET_ACCESS_KEY"] = "submarine_minio" + Repository(JOB_ID).delete_folder() def test_log_param(self): submarine.log_param("name_1", "a") @@ -73,3 +91,21 @@ class TestTracking(unittest.TestCase): assert metrics[0].value == 5 assert metrics[0].id == JOB_ID assert metrics[1].value == 6 + + @pytest.mark.skipif(tensorflow.version.VERSION < "2.0", reason="using tensorflow 2") + def test_save_model(self): + input_arr = 
tensorflow.random.uniform((1, 5)) + model = LinearNNModel() + model(input_arr) + registered_model_name = "registerd_model_name" + self.client.save_model("tensorflow", model, "name_1", registered_model_name) + self.client.save_model("tensorflow", model, "name_2", registered_model_name) + # Validate model_versions + model_versions = self.model_registry.list_model_versions(registered_model_name) + assert len(model_versions) == 2 + assert model_versions[0].name == registered_model_name + assert model_versions[0].version == 1 + assert model_versions[0].source == f"s3://submarine/{JOB_ID}/name_1/1" + assert model_versions[1].name == registered_model_name + assert model_versions[1].version == 2 + assert model_versions[1].source == f"s3://submarine/{JOB_ID}/name_2/1" diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py index 2fc3392..121691d 100644 --- a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py +++ b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py @@ -18,12 +18,12 @@ import os import mock from submarine.store import DEFAULT_SUBMARINE_JDBC_URL -from submarine.store.sqlalchemy_store import SqlAlchemyStore +from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore from submarine.tracking.utils import ( _JOB_ID_ENV_VAR, _TRACKING_URI_ENV_VAR, get_job_id, - get_sqlalchemy_store, + get_tracking_sqlalchemy_store, ) @@ -35,14 +35,14 @@ def test_get_job_id(): assert get_job_id() == "application_12346789" -def test_get_sqlalchemy_store(): +def test_get_tracking_sqlalchemy_store(): patch_create_engine = mock.patch("sqlalchemy.create_engine") uri = DEFAULT_SUBMARINE_JDBC_URL env = {_TRACKING_URI_ENV_VAR: uri} with mock.patch.dict(os.environ, env), patch_create_engine as mock_create_engine, mock.patch( - "submarine.store.sqlalchemy_store.SqlAlchemyStore._initialize_tables" + "submarine.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_tables" ): - store = 
get_sqlalchemy_store(uri) + store = get_tracking_sqlalchemy_store(uri) assert isinstance(store, SqlAlchemyStore) assert store.db_uri == uri mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True) diff --git a/submarine-sdk/pysubmarine/tests/tracking/tf_model.py b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py new file mode 100644 index 0000000..6b598b9 --- /dev/null +++ b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py @@ -0,0 +1,27 @@ +""" + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +""" +import tensorflow as tf + + +class LinearNNModel(tf.keras.Model): + def __init__(self): + super(LinearNNModel, self).__init__() + self.dense1 = tf.keras.layers.Dense(1, activation=tf.nn.relu) # One in and one out + + def call(self, x): + y_pred = self.dense1(x) + return y_pred diff --git a/website/docs/userDocs/submarine-sdk/tracking.md b/website/docs/userDocs/submarine-sdk/tracking.md index afacf39..f774753 100644 --- a/website/docs/userDocs/submarine-sdk/tracking.md +++ b/website/docs/userDocs/submarine-sdk/tracking.md @@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va > **Parameters** - **uri** \- Submarine record data to Mysql server. 
The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. - By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`. + By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`. More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls) ### `submarine.log_param(key: str, value: str) -> None` diff --git a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md index afacf39..f774753 100644 --- a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md +++ b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md @@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va > **Parameters** - **uri** \- Submarine record data to Mysql server. The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``. - By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`. + By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`. More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls) ### `submarine.log_param(key: str, value: str) -> None` --------------------------------------------------------------------- To unsubscribe, e-mail: dev-unsubscr...@submarine.apache.org For additional commands, e-mail: dev-h...@submarine.apache.org