You are viewing a plain-text version of this content. The hyperlink to its canonical version was not preserved in this copy.
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/10/24 18:36:05 UTC
[submarine] branch master updated: SUBMARINE-1005. Register model version when saving model.
This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 535770e SUBMARINE-1005. Register model version when saving model.
535770e is described below
commit 535770eaa1935d04d5bea13d326406e7c9286d77
Author: jeff-901 <b0...@ntu.edu.tw>
AuthorDate: Sun Oct 17 11:38:51 2021 +0800
SUBMARINE-1005. Register model version when saving model.
### What is this PR for?
Migrate the save-model function to SubmarineClient.
Implement model-registration logic in the save-model function.
### What type of PR is it?
Feature
### Todos
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1005
### How should this be tested?
Via GitHub Actions.
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? No
Author: jeff-901 <b0...@ntu.edu.tw>
Signed-off-by: Kevin <pi...@apache.org>
Closes #772 from jeff-901/SUBMARINE-1005 and squashes the following commits:
3371b5fe [jeff-901] fix duplicate
7a6cbe91 [jeff-901] fix typo
091ae50f [jeff-901] edit document and fix test
d16c93a1 [jeff-901] fix bugs
0f4a9069 [jeff-901] fix model client
c629f25b [jeff-901] add test and remove duplicate code
b70ceedb [jeff-901] add mypy syntax
03f467e7 [jeff-901] checkstyle
e4e13834 [jeff-901] refactor alchemy_store
36631d2c [jeff-901] add save model in submarine client
---
submarine-sdk/pysubmarine/submarine/__init__.py | 2 +
.../pysubmarine/submarine/artifacts/repository.py | 12 +++-
.../pysubmarine/submarine/models/client.py | 26 --------
.../pysubmarine/submarine/store/__init__.py | 3 +-
.../submarine/store/{ => tracking}/__init__.py | 4 --
.../store/{ => tracking}/abstract_store.py | 0
.../store/{ => tracking}/sqlalchemy_store.py | 7 +-
.../pysubmarine/submarine/tracking/client.py | 77 ++++++++++++++++++++--
.../pysubmarine/submarine/tracking/constant.py | 20 ++++++
.../pysubmarine/submarine/tracking/fluent.py | 12 ++++
.../pysubmarine/submarine/tracking/utils.py | 11 +++-
.../tests/store/tracking/test_sqlalchemy_store.py | 2 +-
.../pysubmarine/tests/tracking/test_tracking.py | 38 ++++++++++-
.../pysubmarine/tests/tracking/test_utils.py | 10 +--
.../pysubmarine/tests/tracking/tf_model.py | 27 ++++++++
website/docs/userDocs/submarine-sdk/tracking.md | 2 +-
.../userDocs/submarine-sdk/tracking.md | 2 +-
17 files changed, 205 insertions(+), 50 deletions(-)
diff --git a/submarine-sdk/pysubmarine/submarine/__init__.py b/submarine-sdk/pysubmarine/submarine/__init__.py
index 85519e8..0554922 100644
--- a/submarine-sdk/pysubmarine/submarine/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/__init__.py
@@ -20,12 +20,14 @@ from submarine.models.client import ModelsClient
log_param = submarine.tracking.fluent.log_param
log_metric = submarine.tracking.fluent.log_metric
+save_model = submarine.tracking.fluent.save_model
set_db_uri = utils.set_db_uri
get_db_uri = utils.get_db_uri
__all__ = [
"log_metric",
"log_param",
+ "save_model",
"set_db_uri",
"get_db_uri",
"ExperimentClient",
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 3dff6fc..5a60648 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -50,7 +50,7 @@ class Repository:
key=dest_path,
)
- def log_artifacts(self, local_dir, artifact_path):
+ def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
bucket = "submarine"
dest_path = self.dest_path
list_of_subfolder = self._list_artifact_subfolder(artifact_path)
@@ -71,3 +71,13 @@ class Repository:
bucket=bucket,
key=os.path.join(upload_path, f),
)
+ return f"s3://{bucket}/{dest_path}"
+
+ def delete_folder(self) -> None:
+ objects_to_delete = self.client.list_objects(Bucket="submarine", Prefix=self.dest_path)
+ if objects_to_delete.get("Contents") is not None:
+ delete_keys: dict = {"Objects": []}
+ delete_keys["Objects"] = [
+ {"Key": k} for k in [obj["Key"] for obj in objects_to_delete.get("Contents")]
+ ]
+ self.client.delete_objects(Bucket="submarine", Delete=delete_keys)
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index 9cf655f..e633188 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -15,16 +15,12 @@
under the License.
"""
import os
-import re
-import tempfile
import time
import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
-from submarine.artifacts.repository import Repository
-
from .constant import (
AWS_ACCESS_KEY_ID,
AWS_SECRET_ACCESS_KEY,
@@ -58,7 +54,6 @@ class ModelsClient:
"tensorflow": mlflow.tensorflow.log_model,
"keras": mlflow.keras.log_model,
}
- self.artifact_repo = Repository(get_job_id())
def start(self):
"""
@@ -109,27 +104,6 @@ class ModelsClient:
else:
raise MlflowException("No valid type of model has been matched")
- def save_model_submarine(self, model_type, model, artifact_path, registered_model_name=None):
- pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
- if not re.fullmatch(pattern, artifact_path):
- raise Exception(
- "Artifact_path must only contains numbers, characters, hyphen and underscore. "
- " Artifact_path must starts and ends with numbers or characters."
- )
- with tempfile.TemporaryDirectory() as tempdir:
- if model_type == "pytorch":
- import submarine.models.pytorch
-
- submarine.models.pytorch.save_model(model, tempdir)
- elif model_type == "tensorflow":
- import submarine.models.tensorflow
-
- submarine.models.tensorflow.save_model(model, tempdir)
- else:
- raise Exception("No valid type of model has been matched to {}".format(model_type))
- self.artifact_repo.log_artifacts(tempdir, artifact_path)
- # TODO for registering model ()
-
def _get_or_create_experiment(self, experiment_name):
"""
Return the id of experiment.
diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/__init__.py
index 60412a9..458bc11 100644
--- a/submarine-sdk/pysubmarine/submarine/store/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/store/__init__.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine"
+
+DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@submarine-database:3306/submarine"
__all__ = ["DEFAULT_SUBMARINE_JDBC_URL"]
diff --git a/submarine-sdk/pysubmarine/submarine/store/__init__.py b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
similarity index 85%
copy from submarine-sdk/pysubmarine/submarine/store/__init__.py
copy to submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
index 60412a9..a6eb1b5 100644
--- a/submarine-sdk/pysubmarine/submarine/store/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/__init__.py
@@ -12,7 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-DEFAULT_SUBMARINE_JDBC_URL = "mysql+pymysql://submarine:password@localhost:3306/submarine"
-
-__all__ = ["DEFAULT_SUBMARINE_JDBC_URL"]
diff --git a/submarine-sdk/pysubmarine/submarine/store/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py
similarity index 100%
rename from submarine-sdk/pysubmarine/submarine/store/abstract_store.py
rename to submarine-sdk/pysubmarine/submarine/store/tracking/abstract_store.py
diff --git a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
similarity index 96%
rename from submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
rename to submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
index 01adec5..23e8a8d 100644
--- a/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/tracking/sqlalchemy_store.py
@@ -19,9 +19,10 @@ from contextlib import contextmanager
import sqlalchemy
+from submarine.entities import Param
from submarine.exceptions import SubmarineException
-from submarine.store.abstract_store import AbstractStore
from submarine.store.database.models import Base, SqlMetric, SqlParam
+from submarine.store.tracking.abstract_store import AbstractStore
from submarine.utils import extract_db_type_from_uri
_logger = logging.getLogger(__name__)
@@ -42,7 +43,7 @@ class SqlAlchemyStore(AbstractStore):
:py:class:`submarine.store.database.models.SqlParam`.
"""
- def __init__(self, db_uri):
+ def __init__(self, db_uri: str) -> None:
"""
Create a database backed store.
:param db_uri: The SQLAlchemy database URI string to connect to the database. See
@@ -151,7 +152,7 @@ class SqlAlchemyStore(AbstractStore):
except sqlalchemy.exc.IntegrityError:
session.rollback()
- def log_param(self, job_id, param):
+ def log_param(self, job_id: str, param: Param) -> None:
with self.ManagedSessionMaker() as session:
try:
self._get_or_create(
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 2ee9b09..982dfe8 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -12,30 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+import re
+import tempfile
import time
import submarine
+from submarine.artifacts.repository import Repository
from submarine.entities import Metric, Param
+from submarine.exceptions import SubmarineException
from submarine.tracking import utils
from submarine.utils.validation import validate_metric, validate_param
+from .constant import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, S3_ENDPOINT_URL
+
class SubmarineClient(object):
"""
Client of an submarine Tracking Server that creates and manages experiments and runs.
"""
- def __init__(self, db_uri=None):
+ def __init__(
+ self,
+ db_uri: str = None,
+ s3_registry_uri: str = None,
+ aws_access_key_id: str = None,
+ aws_secret_access_key: str = None,
+ ) -> None:
"""
:param db_uri: Address of local or remote tracking server. If not provided, defaults
to the service set by ``submarine.tracking.set_db_uri``. See
`Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
for more info.
"""
+ os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_registry_uri or S3_ENDPOINT_URL
+ os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id or AWS_ACCESS_KEY_ID
+ os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key or AWS_SECRET_ACCESS_KEY
+ self.artifact_repo = Repository(utils.get_job_id())
self.db_uri = db_uri or submarine.get_db_uri()
- self.store = utils.get_sqlalchemy_store(self.db_uri)
+ self.store = utils.get_tracking_sqlalchemy_store(self.db_uri)
+ self.model_registry = utils.get_model_registry_sqlalchemy_store(self.db_uri)
- def log_metric(self, job_id, key, value, worker_index, timestamp=None, step=None):
+ def log_metric(
+ self,
+ job_id: str,
+ key: str,
+ value: float,
+ worker_index: str,
+ timestamp: int = None,
+ step: int = None,
+ ) -> None:
"""
Log a metric against the run ID.
:param job_id: The job name to which the metric should be logged.
@@ -53,7 +79,7 @@ class SubmarineClient(object):
metric = Metric(key, value, worker_index, timestamp, step)
self.store.log_metric(job_id, metric)
- def log_param(self, job_id, key, value, worker_index):
+ def log_param(self, job_id: str, key: str, value: str, worker_index: str) -> None:
"""
Log a parameter against the job name. Value is converted to a string.
:param job_id: The job name to which the parameter should be logged.
@@ -64,3 +90,46 @@ class SubmarineClient(object):
validate_param(key, value)
param = Param(key, str(value), worker_index)
self.store.log_param(job_id, param)
+
+ def save_model(
+ self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+ ) -> None:
+ """
+ Save a model into the minio pod.
+ :param model_type: The type of the model.
+ :param model: Model.
+ :param artifact_path: Relative path of the artifact in the minio pod.
+ :param registered_model_name: If not None, register model into the model registry with
+ this name. If None, the model only be saved in minio pod.
+ """
+ pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
+ if not re.fullmatch(pattern, artifact_path):
+ raise Exception(
+ "Artifact_path must only contains numbers, characters, hyphen and underscore. "
+ " Artifact_path must starts and ends with numbers or characters."
+ )
+ with tempfile.TemporaryDirectory() as tempdir:
+ if model_type == "pytorch":
+ import submarine.models.pytorch
+
+ submarine.models.pytorch.save_model(model, tempdir)
+ elif model_type == "tensorflow":
+ import submarine.models.tensorflow
+
+ submarine.models.tensorflow.save_model(model, tempdir)
+ else:
+ raise Exception("No valid type of model has been matched to {}".format(model_type))
+ source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
+
+ # Register model
+ if registered_model_name is not None:
+ try:
+ self.model_registry.get_registered_model(registered_model_name)
+ except SubmarineException:
+ self.model_registry.create_registered_model(name=registered_model_name)
+ self.model_registry.create_model_version(
+ name=registered_model_name,
+ source=source,
+ user_id="", # TODO(jeff-901): the user id is needed to be specified.
+ experiment_id=utils.get_job_id(),
+ )
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/constant.py b/submarine-sdk/pysubmarine/submarine/tracking/constant.py
new file mode 100644
index 0000000..201d89a
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/tracking/constant.py
@@ -0,0 +1,20 @@
+"""
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+"""
+
+S3_ENDPOINT_URL = "http://submarine-minio-service:9000"
+AWS_ACCESS_KEY_ID = "submarine_minio"
+AWS_SECRET_ACCESS_KEY = "submarine_minio"
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
index 8406ce7..aabe7ed 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
@@ -52,3 +52,15 @@ def log_metric(key, value, step=None):
job_id = get_job_id()
worker_index = get_worker_index()
SubmarineClient().log_metric(job_id, key, value, worker_index, datetime.now(), step or 0)
+
+
+def save_model(model_type: str, model, artifact_path: str, registered_model_name: str = None):
+ """
+ Save a model into the minio pod.
+ :param model_type: The type of the model.
+ :param model: Model.
+ :param artifact_path: Relative path of the artifact in the minio pod.
+ :param registered_model_name: If none None, register model into the model registry with
+ this name. If None, the model only be saved in minio pod.
+ """
+ SubmarineClient().save_model(model_type, model, artifact_path, registered_model_name)
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
index ec0ec14..4a223e1 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
@@ -19,7 +19,6 @@ import json
import os
import uuid
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
from submarine.utils import env
_TRACKING_URI_ENV_VAR = "SUBMARINE_TRACKING_URI"
@@ -88,5 +87,13 @@ def get_worker_index():
return worker_index
-def get_sqlalchemy_store(store_uri):
+def get_tracking_sqlalchemy_store(store_uri: str):
+ from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
+
+ return SqlAlchemyStore(store_uri)
+
+
+def get_model_registry_sqlalchemy_store(store_uri: str):
+ from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
return SqlAlchemyStore(store_uri)
diff --git a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
index dbdacdc..3104800 100644
--- a/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
@@ -22,7 +22,7 @@ import submarine
from submarine.entities import Metric, Param
from submarine.store.database import models
from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
JOB_ID = "application_123456789"
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
index 7410e16..59feefa 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
@@ -18,13 +18,18 @@ from datetime import datetime
from os import environ
import pytest
+import tensorflow
import submarine
+from submarine.artifacts.repository import Repository
from submarine.store.database import models
from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.tracking.client import SubmarineClient
+
+from .tf_model import LinearNNModel
JOB_ID = "application_123456789"
+MLFLOW_S3_ENDPOINT_URL = "http://localhost:9000"
@pytest.mark.e2e
@@ -35,7 +40,16 @@ class TestTracking(unittest.TestCase):
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
)
self.db_uri = submarine.get_db_uri()
+ self.client = SubmarineClient(
+ db_uri=self.db_uri,
+ s3_registry_uri=MLFLOW_S3_ENDPOINT_URL,
+ )
+ from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
+
self.store = SqlAlchemyStore(self.db_uri)
+ from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
+ self.model_registry = SqlAlchemyStore(self.db_uri)
# TODO: use submarine.tracking.fluent to support experiment create
with self.store.ManagedSessionMaker() as session:
instance = SqlExperiment(
@@ -52,6 +66,10 @@ class TestTracking(unittest.TestCase):
def tearDown(self):
submarine.set_db_uri(None)
models.Base.metadata.drop_all(self.store.engine)
+ environ["MLFLOW_S3_ENDPOINT_URL"] = MLFLOW_S3_ENDPOINT_URL
+ environ["AWS_ACCESS_KEY_ID"] = "submarine_minio"
+ environ["AWS_SECRET_ACCESS_KEY"] = "submarine_minio"
+ Repository(JOB_ID).delete_folder()
def test_log_param(self):
submarine.log_param("name_1", "a")
@@ -73,3 +91,21 @@ class TestTracking(unittest.TestCase):
assert metrics[0].value == 5
assert metrics[0].id == JOB_ID
assert metrics[1].value == 6
+
+ @pytest.mark.skipif(tensorflow.version.VERSION < "2.0", reason="using tensorflow 2")
+ def test_save_model(self):
+ input_arr = tensorflow.random.uniform((1, 5))
+ model = LinearNNModel()
+ model(input_arr)
+ registered_model_name = "registerd_model_name"
+ self.client.save_model("tensorflow", model, "name_1", registered_model_name)
+ self.client.save_model("tensorflow", model, "name_2", registered_model_name)
+ # Validate model_versions
+ model_versions = self.model_registry.list_model_versions(registered_model_name)
+ assert len(model_versions) == 2
+ assert model_versions[0].name == registered_model_name
+ assert model_versions[0].version == 1
+ assert model_versions[0].source == f"s3://submarine/{JOB_ID}/name_1/1"
+ assert model_versions[1].name == registered_model_name
+ assert model_versions[1].version == 2
+ assert model_versions[1].source == f"s3://submarine/{JOB_ID}/name_2/1"
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
index 2fc3392..121691d 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
@@ -18,12 +18,12 @@ import os
import mock
from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
-from submarine.store.sqlalchemy_store import SqlAlchemyStore
+from submarine.store.tracking.sqlalchemy_store import SqlAlchemyStore
from submarine.tracking.utils import (
_JOB_ID_ENV_VAR,
_TRACKING_URI_ENV_VAR,
get_job_id,
- get_sqlalchemy_store,
+ get_tracking_sqlalchemy_store,
)
@@ -35,14 +35,14 @@ def test_get_job_id():
assert get_job_id() == "application_12346789"
-def test_get_sqlalchemy_store():
+def test_get_tracking_sqlalchemy_store():
patch_create_engine = mock.patch("sqlalchemy.create_engine")
uri = DEFAULT_SUBMARINE_JDBC_URL
env = {_TRACKING_URI_ENV_VAR: uri}
with mock.patch.dict(os.environ, env), patch_create_engine as mock_create_engine, mock.patch(
- "submarine.store.sqlalchemy_store.SqlAlchemyStore._initialize_tables"
+ "submarine.store.tracking.sqlalchemy_store.SqlAlchemyStore._initialize_tables"
):
- store = get_sqlalchemy_store(uri)
+ store = get_tracking_sqlalchemy_store(uri)
assert isinstance(store, SqlAlchemyStore)
assert store.db_uri == uri
mock_create_engine.assert_called_once_with(uri, pool_pre_ping=True)
diff --git a/submarine-sdk/pysubmarine/tests/tracking/tf_model.py b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py
new file mode 100644
index 0000000..6b598b9
--- /dev/null
+++ b/submarine-sdk/pysubmarine/tests/tracking/tf_model.py
@@ -0,0 +1,27 @@
+"""
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+"""
+import tensorflow as tf
+
+
+class LinearNNModel(tf.keras.Model):
+ def __init__(self):
+ super(LinearNNModel, self).__init__()
+ self.dense1 = tf.keras.layers.Dense(1, activation=tf.nn.relu) # One in and one out
+
+ def call(self, x):
+ y_pred = self.dense1(x)
+ return y_pred
diff --git a/website/docs/userDocs/submarine-sdk/tracking.md b/website/docs/userDocs/submarine-sdk/tracking.md
index afacf39..f774753 100644
--- a/website/docs/userDocs/submarine-sdk/tracking.md
+++ b/website/docs/userDocs/submarine-sdk/tracking.md
@@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va
> **Parameters**
- **uri** \- Submarine record data to Mysql server. The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``.
- By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`.
+ By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`.
More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls)
### `submarine.log_param(key: str, value: str) -> None`
diff --git a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
index afacf39..f774753 100644
--- a/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
+++ b/website/versioned_docs/version-0.6.0/userDocs/submarine-sdk/tracking.md
@@ -44,7 +44,7 @@ set the tracking URI. You can also set the SUBMARINE_TRACKING_URI environment va
> **Parameters**
- **uri** \- Submarine record data to Mysql server. The database URL is expected in the format ``<dialect>+<driver>://<username>:<password>@<host>:<port>/<database>``.
- By default it's `mysql+pymysql://submarine:password@localhost:3306/submarine`.
+ By default it's `mysql+pymysql://submarine:password@submarine-database:3306/submarine`.
More detail : [SQLAlchemy docs](https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls)
### `submarine.log_param(key: str, value: str) -> None`
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org