You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/10/24 18:36:05 UTC

[submarine] branch master updated: SUBMARINE-1005. Register model version when saving model.

This is an automated email from the ASF dual-hosted git repository.

pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 535770e  SUBMARINE-1005. Register model version when saving model.
535770e is described below

commit 535770eaa1935d04d5bea13d326406e7c9286d77
Author: jeff-901 <b0...@ntu.edu.tw>
AuthorDate: Sun Oct 17 11:38:51 2021 +0800

    SUBMARINE-1005. Register model version when saving model.
    
    ### What is this PR for?
    Migrate the save_model function to SubmarineClient.
    Implement the model-registration logic in the save_model function.
    
    ### What type of PR is it?
    Feature
    
    ### Todos
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-1005
    
    ### How should this be tested?
    Via the GitHub Actions CI workflow.
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need new documentation? No
    
    Author: jeff-901 <b0...@ntu.edu.tw>
    
    Signed-off-by: Kevin <pi...@apache.org>
    
    Closes #772 from jeff-901/SUBMARINE-1005 and squashes the following commits:
    
    3371b5fe [jeff-901] fix duplicate
    7a6cbe91 [jeff-901] fix typo
    091ae50f [jeff-901] edit document and fix test
    d16c93a1 [jeff-901] fix bugs
    0f4a9069 [jeff-901] fix model client
    c629f25b [jeff-901] add test and remove duplicate code
    b70ceedb [jeff-901] add mypy syntax
    03f467e7 [jeff-901] checkstyle
    e4e13834 [jeff-901] refactor alchemy_store
    36631d2c [jeff-901] add save model in submarine client
---
 submarine-sdk/pysubmarine/submarine/__init__.py    |  2 +
 .../pysubmarine/submarine/artifacts/repository.py  | 12 +++-
 .../pysubmarine/submarine/models/client.py         | 26 --------
 .../pysubmarine/submarine/store/__init__.py        |  3 +-
 .../submarine/store/{ => tracking}/__init__.py     |  4 --
 .../store/{ => tracking}/abstract_store.py         |  0
 .../store/{ => tracking}/sqlalchemy_store.py       |  7 +-
 .../pysubmarine/submarine/tracking/client.py       | 77 ++++++++++++++++++++--
 .../pysubmarine/submarine/tracking/constant.py     | 20 ++++++
 .../pysubmarine/submarine/tracking/fluent.py       | 12 ++++
 .../pysubmarine/submarine/tracking/utils.py        | 11 +++-
 .../tests/store/tracking/test_sqlalchemy_store.py  |  2 +-
 .../pysubmarine/tests/tracking/test_tracking.py    | 38 ++++++++++-
 .../pysubmarine/tests/tracking/test_utils.py       | 10 +--
 .../pysubmarine/tests/tracking/tf_model.py         | 27 ++++++++
 website/docs/userDocs/submarine-sdk/tracking.md    |  2 +-
 .../userDocs/submarine-sdk/tracking.md             |  2 +-
 17 files changed, 205 insertions(+), 50 deletions(-)

diff --git a/submarine-sdk/pysubmarine/submarine/__init__.py b/submarine-sdk/pysubmarine/submarine/__init__.py
index 85519e8..0554922 100644
--- a/submarine-sdk/pysubmarine/submarine/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/__init__.py
@@ -20,12 +20,14 @@ from submarine.models.client import ModelsClient
 
 log_param = submarine.tracking.fluent.log_param
 log_metric = submarine.tracking.fluent.log_metric
+save_model = submarine.tracking.fluent.save_model
 set_db_uri = utils.set_db_uri
 get_db_uri = utils.get_db_uri
 
 __all__ = [
     "log_metric",
     "log_param",
+    "save_model",
     "set_db_uri",
     "get_db_uri",
     "ExperimentClient",
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 3dff6fc..5a60648 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -50,7 +50,7 @@ class Repository:
             key=dest_path,
         )
 
-    def log_artifacts(self, local_dir, artifact_path):
+    def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
         bucket = "submarine"
         dest_path = self.dest_path
         list_of_subfolder = self._list_artifact_subfolder(artifact_path)
@@ -71,3 +71,13 @@ class Repository:
                     bucket=bucket,
                     key=os.path.join(upload_path, f),
                 )
+        return f"s3://{bucket}/{dest_path}"
+
+    def delete_folder(self) -> None:
+        objects_to_delete = self.client.list_objects(Bucket="submarine", Prefix=self.dest_path)
+        if objects_to_delete.get("Contents") is not None:
+            delete_keys: dict = {"Objects": []}
+            delete_keys["Objects"] = [
+                {"Key": k} for k in [obj["Key"] for obj in objects_to_delete.get("Contents")]
+            ]
+            self.client.delete_objects(Bucket="submarine", Delete=delete_keys)
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index 9cf655f..e633188 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -15,16 +15,12 @@
  under the License.
 """
 import os
-import re
-import tempfile
 import time
 
 import mlflow
 from mlflow.exceptions import MlflowException
 from mlflow.tracking import MlflowClient
 
-from submarine.artifacts.repository import Repository
-
 from .constant import (
     AWS_ACCESS_KEY_ID,
     AWS_SECRET_ACCESS_KEY,
@@ -58,7 +54,6 @@ class ModelsClient:
             "tensorflow": mlflow.tensorflow.log_model,
             "keras": mlflow.keras.log_model,
         }
-        self.artifact_repo = Repository(get_job_id())
 
     def start(self):
         """
@@ -109,27 +104,6 @@ class ModelsClient:
             else:
                 raise MlflowException("No valid type of model has been matched")
 
-    def save_model_submarine(self, model_type, model, artifact_path, registered_model_name=None):
-        pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
-        if not re.fullmatch(pattern, artifact_path):
-            raise Exception(
-                "Artifact_path must only contains numbers, characters, hyphen and underscore.      "
-                "        Artifact_path must starts and ends with numbers or characters."
-            )
-        with tempfile.TemporaryDirectory() as tempdir:
-            if model_type == "pytorch":
-                import submarine.models.pytorch
-
-                submarine.models.pytorch.save_model(model, tempdir)
-            elif model_type == "tensorflow":
-                import submarine.models.tensorflow
-
-                submarine.models.tensorflow.save_model(model, tempdir)
-            else:
-                raise Exception("No valid type of model has been matched to {}".format(model_type))
-            self.artifact_repo.log_artifacts(tempdir, artifact_path)
-        # TODO for registering model ()
-
     def _get_or_create_experiment(self, experiment_name):
         """
         Return the id of experiment.
diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/__init__.py
index 60412a9..458bc11 100644
--- a/submarine-sdk/pysubmarine/submarine/store/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/store/__init__.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine"
+
+DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@submarine-database:3306/submarine"
 
 __all__ = ["DEFAULT_SUBMARINE_JDBC_URL"]
diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
similarity index 85%
copy from submarine-sdk/pysubmarine/submarine/store/__init__.py
copy to submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
index 60412a9..a6eb1b5 100644
--- a/submarine-sdk/pysubmarine/submarine/store/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
@@ -12,7 +12,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine"
-
-__all__ = ["DEFAULT_SUBMARINE_JDBC_URL"]
diff --git a/submarine-sdk/pysubmarine/submarine/store/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py
similarity index 100%
rename from submarine-sdk/pysubmarine/submarine/store/abstract_store.py
rename to submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py
diff --git a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
similarity index 96%
rename from submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
rename to submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
index 01adec5..23e8a8d 100644
--- a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
@@ -19,9 +19,10 @@ from contextlib import contextmanager
 
 import sqlalchemy
 
+from submarine.entities import Param
 from submarine.exceptions import SubmarineException
-from submarine.store.abstract_store import AbstractStore
 from submarine.store.database.models import Base, SqlMetric, SqlParam
+from submarine.store.tracking.abstract_store import AbstractStore
 from submarine.utils import extract_db_type_from_uri
 
 _logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ class SqlAlchemyStore(AbstractStore):
     :py:class:`submarine.store.database.models.SqlParam`.
     """
 
-    def __init__(self, db_uri):
+    def __init__(self, db_uri: str) -> None:
         """
         Create a database backed store.
         :param db_uri: The SQLAlchemy database URI string to connect to the database. See
@@ -151,7 +152,7 @@ class SqlAlchemyStore(AbstractStore):
             except sqlalchemy.exc.IntegrityError:
                 session.rollback()
 
-    def log_param(self, job_id, param):
+    def log_param(self, job_id: str, param: Param) -> None:
         with self.ManagedSessionMaker() as session:
             try:
                 self._get_or_create(
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 2ee9b09..982dfe8 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -12,30 +12,56 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import re
+import tempfile
 import time
 
 import submarine
+from submarine.artifacts.repository import Repository
 from submarine.entities import Metric, Param
+from submarine.exceptions import SubmarineException
 from submarine.tracking import utils
 from submarine.utils.validation import validate_metric, validate_param
 
+from .constant import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_ENDPOINT_URL
+
 
 class SubmarineClient(object):
     """
     Client of an submarine Tracking Server that creates and manages experiments and runs.
     """
 
-    def __init__(self, db_uri=None):
+    def __init__(
+        self,
+        db_uri: str = None,
+        s3_registry_uri: str = None,
+        aws_access_key_id: str = None,
+        aws_secret_access_key: str = None,
+    ) -> None:
         """
         :param db_uri: Address of local or remote tracking server. If not provided, defaults
                              to the service set by ``submarine.tracking.set_db_uri``. See
                              `Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
                              for more info.
         """
+        os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_registry_uri or S3_ENDPOINT_URL
+        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id or AWS_ACCESS_KEY_ID
+        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key or AWS_SECRET_ACCESS_KEY
+        self.artifact_repo = Repository(utils.get_job_id())
         self.db_uri = db_uri or submarine.get_db_uri()
-        self.store = utils.get_sqlalchemy_store(self.db_uri)
+        self.store = utils.get_tracking_sqlalchemy_store(self.db_uri)
+        self.model_registry = utils.get_model_registry_sqlalchemy_store(self.db_uri)
 
-    def log_metric(self, job_id, key, value, worker_index, timestamp=None, step=None):
+    def log_metric(
+        self,
+        job_id: str,
+        key: str,
+        value: float,
+        worker_index: str,
+        timestamp: int = None,
+        step: int = None,
+    ) -> None:
         """
         Log a metric against the run ID.
         :param job_id: The job name to which the metric should be logged.
@@ -53,7 +79,7 @@ class SubmarineClient(object):
         metric = Metric(key, value, worker_index, timestamp, step)
         self.store.log_metric(job_id, metric)
 
-    def log_param(self, job_id, key, value, worker_index):
+    def log_param(self, job_id: str, key: str, value: str, worker_index: str) -> None:
         """
         Log a parameter against the job name. Value is converted to a string.
         :param job_id: The job name to which the parameter should be logged.
@@ -64,3 +90,46 @@ class SubmarineClient(object):
         validate_param(key, value)
         param = Param(key, str(value), worker_index)
         self.store.log_param(job_id, param)
+
+    def save_model(
+        self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+    ) -> None:
+        """
+        Save a model into the minio pod.
+        :param model_type: The type of the model.
+        :param model: Model.
+        :param artifact_path: Relative path of the artifact in the minio pod.
+        :param registered_model_name: If not None, register model into the model registry with
+                                      this name. If None, the model only be saved in minio pod.
+        """
+        pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
+        if not re.fullmatch(pattern, artifact_path):
+            raise Exception(
+                "Artifact_path must only contains numbers, characters, hyphen and underscore.      "
+                "        Artifact_path must starts and ends with numbers or characters."
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            if model_type == "pytorch":
+                import submarine.models.pytorch
+
+                submarine.models.pytorch.save_model(model, tempdir)
+            elif model_type == "tensorflow":
+                import submarine.models.tensorflow
+
+                submarine.models.tensorflow.save_model(model, tempdir)
+            else:
+                raise Exception("No valid type of model has been matched to {}".format(model_type))
+            source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
+
+        # Register model
+        if registered_model_name is not None:
+            try:
+                self.model_registry.get_registered_model(registered_model_name)
+            except SubmarineException:
+                self.model_registry.create_registered_model(name=registered_model_name)
+            self.model_registry.create_model_version(
+                name=registered_model_name,
+                source=source,
+                user_id="",  # TODO(jeff-901): the user id is needed to be specified.
+                experiment_id=utils.get_job_id(),
+            )
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/constant.py b/submarine-sdk/pysubmarine/submarine/tracking/constant.py
new file mode 100644
index 0000000..201d89a
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/tracking/constant.py
@@ -0,0 +1,20 @@
+"""
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+"""
+
+S3_ENDPOINT_URL = "http://submarine-minio-service:9000"
+AWS_ACCESS_KEY_ID = "submarine_minio"
+AWS_SECRET_ACCESS_KEY = "submarine_minio"
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
index 8406ce7..aabe7ed 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
@@ -52,3 +52,15 @@ def log_metric(key, value, step=None):
     job_id = get_job_id()
     worker_index = get_worker_index()
     SubmarineClient().log_metric(job_id, key, value, worker_index, datetime.now(), step or 0)
+
+
+def save_model(model_type: str, model, artifact_path: str, registered_model_name: str = None):
+    """
+    Save a model into the minio pod.
+    :param model_type: The type of the model.
+    :param model: Model.
+    :param artifact_path: Relative path of the artifact in the minio pod.
+    :param registered_model_name: If none None, register model into the model registry with
+                                  this name. If None, the model only be saved in minio pod.
+    """
+    SubmarineClient().save_model(model_type, model, artifact_path, registered_model_name)
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
index ec0ec14..4a223e1 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
@@ -19,7 +19,6 @@ import json
 import os
 import uuid
 
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
 from submarine.utils import env
 
 _TRACKING_URI_ENV_VAR = "SUBMARINE_TRACKING_URI"
@@ -88,5 +87,13 @@ def get_worker_index():
     return worker_index
 
 
-def get_sqlalchemy_store(store_uri):
+def get_tracking_sqlalchemy_store(store_uri: str):
+    from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
+
+    return SqlAlchemyStore(store_uri)
+
+
+def get_model_registry_sqlalchemy_store(store_uri: str):
+    from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
     return SqlAlchemyStore(store_uri)
diff --git a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
index dbdacdc..3104800 100644
--- a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
@@ -22,7 +22,7 @@ import submarine
 from submarine.entities import Metric, Param
 from submarine.store.database import models
 from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
 
 JOB_ID = "application_123456789"
 
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
index 7410e16..59feefa 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
@@ -18,13 +18,18 @@ from datetime import datetime
 from os import environ
 
 import pytest
+import tensorflow
 
 import submarine
+from submarine.artifacts.repository import Repository
 from submarine.store.database import models
 from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.tracking.client import SubmarineClient
+
+from .tf_model import LinearNNModel
 
 JOB_ID = "application_123456789"
+MLFLOW_S3_ENDPOINT_URL = "http://localhost:9000"
 
 
 @pytest.mark.e2e
@@ -35,7 +40,16 @@ class TestTracking(unittest.TestCase):
             "mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
         )
         self.db_uri = submarine.get_db_uri()
+        self.client = SubmarineClient(
+            db_uri=self.db_uri,
+            s3_registry_uri=MLFLOW_S3_ENDPOINT_URL,
+        )
+        from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
+
         self.store = SqlAlchemyStore(self.db_uri)
+        from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
+        self.model_registry = SqlAlchemyStore(self.db_uri)
         # TODO: use submarine.tracking.fluent to support experiment create
         with self.store.ManagedSessionMaker() as session:
             instance = SqlExperiment(
@@ -52,6 +66,10 @@ class TestTracking(unittest.TestCase):
     def tearDown(self):
         submarine.set_db_uri(None)
         models.Base.metadata.drop_all(self.store.engine)
+        environ["MLFLOW_S3_ENDPOINT_URL"] = MLFLOW_S3_ENDPOINT_URL
+        environ["AWS_ACCESS_KEY_ID"] = "submarine_minio"
+        environ["AWS_SECRET_ACCESS_KEY"] = "submarine_minio"
+        Repository(JOB_ID).delete_folder()
 
     def test_log_param(self):
         submarine.log_param("name_1", "a")
@@ -73,3 +91,21 @@ class TestTracking(unittest.TestCase):
             assert metrics[0].value == 5
             assert metrics[0].id == JOB_ID
             assert metrics[1].value == 6
+
+    @pytest.mark.skipif(tensorflow.version.VERSION < "2.0", reason="using tensorflow 2")
+    def test_save_model(self):
+        input_arr = tensorflow.random.uniform((1, 5))
+        model = LinearNNModel()
+        model(input_arr)
+        registered_model_name = "registerd_model_name"
+        self.client.save_model("tensorflow", model, "name_1", registered_model_name)
+        self.client.save_model("tensorflow", model, "name_2", registered_model_name)
+        # Validate model_versions
+        model_versions = self.model_registry.list_model_versions(registered_model_name)
+        assert len(model_versions) == 2
+        assert model_versions[0].name == registered_model_name
+        assert model_versions[0].version == 1
+        assert model_versions[0].source == f"s3://submarine/{JOB_ID}/name_1/1"
+        assert model_versions[1].name == registered_model_name
+        assert model_versions[1].version == 2
+        assert model_versions[1].source == f"s3://submarine/{JOB_ID}/name_2/1"
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
index 2fc3392..121691d 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
@@ -18,12 +18,12 @@ import os
 import mock
 
 from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
 from submarine.tracking.utils import (
     _JOB_ID_ENV_VAR,
     _TRACKING_URI_ENV_VAR,
     get_job_id,
-    get_sqlalchemy_store,
+    get_tracking_sqlalchemy_store,
 )
 
 
@@ -35,14 +35,14 @@ def test_get_job_id():
         assert get_job_id() == "application_12346789"
 
 
-def test_get_sqlalchemy_store():
+def test_get_tracking_sqlalchemy_store():
     patch_create_engine = mock.patch("sqlalchemy.create_engine")
     uri = DEFAULT_SUBMARINE_JDBC_URL
     env = {_TRACKING_URI_ENV_VAR: uri}
     with mock.patch.dict(os.environ, env), patch_create_engine as mock_create_engine, mock.patch(
-        "submarine.store.sqlalchemy_store.SqlAlchemyStore._initialize_tables"
+        "submarine.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_tables"
     ):
-        store = get_sqlalchemy_store(uri)
+        store = get_tracking_sqlalchemy_store(uri)
         assert isinstance(store, SqlAlchemyStore)
         assert store.db_uri == uri
     mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True)
diff --git a/submarine-sdk/pysubmarine/tests/tracking/tf_model.py b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py
new file mode 100644
index 0000000..6b598b9
--- /dev/null
+++ b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py
@@ -0,0 +1,27 @@
+"""
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+"""
+import tensorflow as tf
+
+
+class LinearNNModel(tf.keras.Model):
+    def __init__(self):
+        super(LinearNNModel, self).__init__()
+        self.dense1 = tf.keras.layers.Dense(1, activation=tf.nn.relu)  # One in and one out
+
+    def call(self, x):
+        y_pred = self.dense1(x)
+        return y_pred
diff --git a/website/docs/userDocs/submarine-sdk/tracking.md b/website/docs/userDocs/submarine-sdk/tracking.md
index afacf39..f774753 100644
--- a/website/docs/userDocs/submarine-sdk/tracking.md
+++ b/website/docs/userDocs/submarine-sdk/tracking.md
@@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va
 
 > **Parameters**
   - **uri** \- Submarine record data to Mysql server. The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``.
-  By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`.
+  By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`.
   More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls)
 
 ### `submarine.log_param(key: str, value: str) -> None`
diff --git a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
index afacf39..f774753 100644
--- a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
+++ b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
@@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va
 
 > **Parameters**
   - **uri** \- Submarine record data to Mysql server. The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``.
-  By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`.
+  By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`.
   More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls)
 
 ### `submarine.log_param(key: str, value: str) -> None`

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org