You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2021/10/01 16:38:54 UTC
[submarine] branch master updated: SUBMARINE-1023. Submarine SDK
sqlalchemy store (model registry)
This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 4e4cfd9 SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)
4e4cfd9 is described below
commit 4e4cfd9162218a8806bac11728c22cd161f18f7b
Author: KUAN-HSUN-LI <b0...@ntu.edu.tw>
AuthorDate: Sun Sep 26 00:47:53 2021 +0800
SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)
### What is this PR for?
* Implement the model registry SQL methods in the Python SDK
* Apply sqlalchemy mypy checks
* Replace submarine tracking_uri with db_uri
### What type of PR is it?
[Feature]
### Todos
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1023
### How should this be tested?
All of the tests are provided in `submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py`
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? No
Author: KUAN-HSUN-LI <b0...@ntu.edu.tw>
Signed-off-by: Kevin <pi...@apache.org>
Closes #752 from KUAN-HSUN-LI/SUBMARINE-1023 and squashes the following commits:
77237889 [KUAN-HSUN-LI] SUBMARINE-1023. Submarine model registry Python SDK
5e086b57 [KUAN-HSUN-LI] fix
8c6574da [KUAN-HSUN-LI] fix
dd41039c [KUAN-HSUN-LI] replace tracking_uri with db_uri
ad9c5d7d [KUAN-HSUN-LI] add init file
374aa2dd [KUAN-HSUN-LI] SUBMARINE-1023. Submarine model registry Python SDK
b502fc17 [KUAN-HSUN-LI] SUBMARINE-1023. sqlalchemy store tests
343780f5 [KUAN-HSUN-LI] apply sqlalchemy mypy
---
dev-support/database/submarine-model.sql | 19 +-
.../style-check/python/mypy-requirements.txt | 1 +
pyproject.toml | 4 +-
submarine-sdk/pysubmarine/.style.yapf | 4 -
.../github-actions/test-requirements.txt | 2 +-
submarine-sdk/pysubmarine/setup.py | 2 +-
submarine-sdk/pysubmarine/submarine/__init__.py | 10 +-
.../submarine/entities/model_registry/__init__.py | 6 +-
.../{model_version_stages.py => model_stages.py} | 16 +-
.../entities/model_registry/model_version.py | 37 +-
.../{model_tag.py => model_version_tag.py} | 4 +-
.../entities/model_registry/registered_model.py | 9 +-
.../model_registry/registered_model_tag.py | 2 +-
.../pysubmarine/submarine/store/database/models.py | 216 +++---
.../model_registry/__init__.py} | 7 -
.../store/model_registry/abstract_store.py | 234 +++++++
.../store/model_registry/sqlalchemy_store.py | 578 ++++++++++++++++
.../pysubmarine/submarine/tracking/__init__.py | 9 +-
.../pysubmarine/submarine/tracking/client.py | 11 +-
.../pysubmarine/submarine/tracking/utils.py | 34 -
.../pysubmarine/submarine/utils/__init__.py | 7 +
.../submarine/utils/{__init__.py => db_utils.py} | 45 +-
.../pysubmarine/submarine/utils/validation.py | 43 +-
.../entities/model_registry/test_model_version.py | 47 +-
.../model_registry/test_registered_model.py | 3 +-
.../pysubmarine/tests/store/__init__.py | 6 -
.../tests/store/model_registry/__init__.py | 6 -
.../store/model_registry/test_sqlalchemy_store.py | 739 +++++++++++++++++++++
.../pysubmarine/tests/store/tracking/__init__.py | 6 -
.../store/{ => tracking}/test_sqlalchemy_store.py | 12 +-
.../pysubmarine/tests/tracking/test_tracking.py | 10 +-
.../pysubmarine/tests/tracking/test_utils.py | 18 -
.../__init__.py => tests/utils/test_db_utils.py} | 42 +-
33 files changed, 1873 insertions(+), 316 deletions(-)
diff --git a/dev-support/database/submarine-model.sql b/dev-support/database/submarine-model.sql
index da124e9..abad0bd 100644
--- a/dev-support/database/submarine-model.sql
+++ b/dev-support/database/submarine-model.sql
@@ -31,26 +31,27 @@ CREATE TABLE `registered_model_tag` (
DROP TABLE IF EXISTS `model_version`;
CREATE TABLE `model_version` (
- `name` VARCHAR(256) NOT NULL,
+ `name` VARCHAR(256) NOT NULL COMMENT 'Name of model',
`version` INTEGER NOT NULL,
+ `source` VARCHAR(512) NOT NULL COMMENT 'Model saved link',
`user_id` VARCHAR(64) NOT NULL COMMENT 'Id of the created user',
`experiment_id` VARCHAR(64) NOT NULL,
- `current_stage` VARCHAR(20) COMMENT 'Model stage ex: None, production...',
+ `current_stage` VARCHAR(64) COMMENT 'Model stage ex: None, production...',
`creation_time` DATETIME(3) COMMENT 'Millisecond precision',
`last_updated_time` DATETIME(3) COMMENT 'Millisecond precision',
- `source` VARCHAR(512) COMMENT 'Model saved link',
`dataset` VARCHAR(256) COMMENT 'Which dataset is used',
`description` VARCHAR(5000),
CONSTRAINT `model_version_pk` PRIMARY KEY (`name`, `version`),
- FOREIGN KEY(`name`) REFERENCES `registered_model` (`name`) ON UPDATE CASCADE ON DELETE CASCADE
+ FOREIGN KEY(`name`) REFERENCES `registered_model` (`name`) ON UPDATE CASCADE ON DELETE CASCADE,
+ UNIQUE(`source`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
-DROP TABLE IF EXISTS `model_tag`;
-CREATE TABLE `model_tag` (
- `name` VARCHAR(256) NOT NULL,
+DROP TABLE IF EXISTS `model_version_tag`;
+CREATE TABLE `model_version_tag` (
+ `name` VARCHAR(256) NOT NULL COMMENT 'Name of model',
`version` INTEGER NOT NULL,
`tag` VARCHAR(256) NOT NULL,
- CONSTRAINT `model_tag_pk` PRIMARY KEY (`name`, `version`, `tag`),
+ CONSTRAINT `model_version_tag_pk` PRIMARY KEY (`name`, `version`, `tag`),
FOREIGN KEY(`name`, `version`) REFERENCES `model_version` (`name`, `version`) ON UPDATE CASCADE ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
@@ -72,7 +73,7 @@ DROP TABLE IF EXISTS `param`;
CREATE TABLE `param` (
`id` VARCHAR(64) NOT NULL COMMENT 'Id of the Experiment',
`key` VARCHAR(190) NOT NULL COMMENT '`String` (limit 190 characters). Part of *Primary Key* for ``param`` table.',
- `value` VARCHAR(32) NOT NULL COMMENT '`String` (limit 190 characters). Defined as *Non-null* in schema.',
+ `value` VARCHAR(190) NOT NULL COMMENT '`String` (limit 190 characters). Defined as *Non-null* in schema.',
`worker_index` VARCHAR(32) NOT NULL COMMENT '`String` (limit 32 characters). Part of *Primary Key* for\r\n ``metric`` table.',
CONSTRAINT `param_pk` PRIMARY KEY (`id`, `key`, `worker_index`),
FOREIGN KEY(`id`) REFERENCES `experiment` (`id`) ON UPDATE CASCADE ON DELETE CASCADE
diff --git a/dev-support/style-check/python/mypy-requirements.txt b/dev-support/style-check/python/mypy-requirements.txt
index 6a581e3..8f1e19b 100644
--- a/dev-support/style-check/python/mypy-requirements.txt
+++ b/dev-support/style-check/python/mypy-requirements.txt
@@ -18,3 +18,4 @@ types-requests==2.25.6
types-certifi==2020.4.0
types-six==1.16.1
types-python-dateutil==2.8.0
+sqlalchemy[mypy]
diff --git a/pyproject.toml b/pyproject.toml
index e370133..cdf4ebd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,4 +20,6 @@ profile = "black"
line_length = 100
[tool.black]
max-line-length = 100
-line-length = 100
\ No newline at end of file
+line-length = 100
+[tool.mypy]
+plugins = "sqlalchemy.ext.mypy.plugin"
\ No newline at end of file
diff --git a/submarine-sdk/pysubmarine/.style.yapf b/submarine-sdk/pysubmarine/.style.yapf
deleted file mode 100644
index 34e7202..0000000
--- a/submarine-sdk/pysubmarine/.style.yapf
+++ /dev/null
@@ -1,4 +0,0 @@
-[style]
-based_on_style = google
-indent_width: 4
-continuation_indent_width: 4
diff --git a/submarine-sdk/pysubmarine/github-actions/test-requirements.txt b/submarine-sdk/pysubmarine/github-actions/test-requirements.txt
index cc618b4..54a363b 100644
--- a/submarine-sdk/pysubmarine/github-actions/test-requirements.txt
+++ b/submarine-sdk/pysubmarine/github-actions/test-requirements.txt
@@ -23,7 +23,6 @@ pytest==3.2.1
pytest-cov==2.6.0
pytest-localserver==0.5.0
pylint==2.5.2
-sqlalchemy==1.3.0
PyMySQL==0.9.3
pytest-mock==1.13.0
certifi >= 14.05.14
@@ -31,3 +30,4 @@ six >= 1.10
python_dateutil >= 2.5.3
setuptools >= 21.0.0
urllib3 >= 1.15.1
+freezegun==1.1.0
diff --git a/submarine-sdk/pysubmarine/setup.py b/submarine-sdk/pysubmarine/setup.py
index d10855b..6cf2aed 100644
--- a/submarine-sdk/pysubmarine/setup.py
+++ b/submarine-sdk/pysubmarine/setup.py
@@ -30,7 +30,7 @@ setup(
"six>=1.10.0",
"numpy==1.18.5",
"pandas",
- "sqlalchemy",
+ "sqlalchemy>=1.4.0",
"sqlparse",
"pymysql",
"requests==2.26.0",
diff --git a/submarine-sdk/pysubmarine/submarine/__init__.py b/submarine-sdk/pysubmarine/submarine/__init__.py
index 979e996..85519e8 100644
--- a/submarine-sdk/pysubmarine/submarine/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/__init__.py
@@ -13,21 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import submarine.tracking as tracking
import submarine.tracking.fluent
+import submarine.utils as utils
from submarine.experiment.api.experiment_client import ExperimentClient
from submarine.models.client import ModelsClient
log_param = submarine.tracking.fluent.log_param
log_metric = submarine.tracking.fluent.log_metric
-set_tracking_uri = tracking.set_tracking_uri
-get_tracking_uri = tracking.get_tracking_uri
+set_db_uri = utils.set_db_uri
+get_db_uri = utils.get_db_uri
__all__ = [
"log_metric",
"log_param",
- "set_tracking_uri",
- "get_tracking_uri",
+ "set_db_uri",
+ "get_db_uri",
"ExperimentClient",
"ModelsClient",
]
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/__init__.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/__init__.py
index 7f537e8..d88705a 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/__init__.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from submarine.entities.model_registry.model_tag import ModelTag
from submarine.entities.model_registry.model_version import ModelVersion
+from submarine.entities.model_registry.model_version_tag import ModelVersionTag
from submarine.entities.model_registry.registered_model import RegisteredModel
from submarine.entities.model_registry.registered_model_tag import RegisteredModelTag
__all__ = [
- "ModelVersion",
- "ModelTag",
"RegisteredModel",
"RegisteredModelTag",
+ "ModelVersion",
+ "ModelVersionTag",
]
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
similarity index 64%
copy from submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py
copy to submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
index a8fae8e..4a5e565 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_stages.py
@@ -13,9 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from submarine.exceptions import SubmarineException
+
STAGE_NONE = "None"
-STAGE_STAGING = "Staging"
+STAGE_DEVELOPING = "Developing"
STAGE_PRODUCTION = "Production"
STAGE_ARCHIVED = "Archived"
-ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
+STAGE_DELETED_INTERNAL = "Deleted_Internal"
+
+ALL_STAGES = [STAGE_NONE, STAGE_DEVELOPING, STAGE_PRODUCTION, STAGE_ARCHIVED]
+_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
+
+
+def get_canonical_stage(stage):
+ key = stage.lower()
+ if key not in _CANONICAL_MAPPING:
+ raise SubmarineException(f"Invalid Model Version stage {stage}.")
+ return _CANONICAL_MAPPING[key]
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 88c9ff4..86652b6 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -18,78 +18,78 @@ from submarine.entities._submarine_object import _SubmarineObject
class ModelVersion(_SubmarineObject):
"""
- Model Version object.
+ Model version object.
"""
def __init__(
self,
name,
version,
+ source,
user_id,
experiment_id,
current_stage,
creation_time,
last_updated_time,
- source,
dataset=None,
description=None,
tags=None,
):
self._name = name
self._version = version
+ self._source = source
self._user_id = user_id
self._experiment_id = experiment_id
self._current_stage = current_stage
self._creation_time = creation_time
self._last_updated_time = last_updated_time
- self._source = source
self._dataset = dataset
self._description = description
self._tags = [tag.tag for tag in (tags or [])]
@property
def name(self):
- """String. Unique name within Model Registry."""
+ """String. Registered model name"""
return self._name
@property
def version(self):
- """String. version"""
+ """Integer. version"""
return self._version
@property
+ def source(self):
+ """String. Source path for the model."""
+ return self._source
+
+ @property
def user_id(self):
- """String. User ID that created this model version."""
+ """String. User ID that created this version."""
return self._user_id
@property
def experiment_id(self):
- """String. Experiment ID that created this model version."""
+ """String. Experiment ID that created this version."""
return self._experiment_id
@property
def creation_time(self):
- """Datetime object. Model version creation timestamp."""
+ """Datetime object. The creation datetime of this version."""
return self._creation_time
@property
def last_updated_time(self):
- """Datetime object. Timestamp of last update for this model version."""
+ """Datetime object. Datetime of last update for this version."""
return self._last_updated_time
@property
- def source(self):
- """String. Source path for the model."""
- return self._source
-
- @property
def current_stage(self):
- """String. Current stage of this model version."""
+ """String. Current stage of this version."""
return self._current_stage
@property
def dataset(self):
- """String. Dataset used by this model version"""
+ """String. Dataset used for this version."""
return self._dataset
@property
@@ -99,8 +99,5 @@ class ModelVersion(_SubmarineObject):
@property
def tags(self):
- """List of strings"""
+ """List of strings."""
return self._tags
-
- def _add_tag(self, tag):
- self._tags.append(tag)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_tag.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_tag.py
similarity index 93%
rename from submarine-sdk/pysubmarine/submarine/entities/model_registry/model_tag.py
rename to submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_tag.py
index 5cbcae1..d30ea26 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_tag.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_tag.py
@@ -16,7 +16,7 @@
from submarine.entities._submarine_object import _SubmarineObject
-class ModelTag(_SubmarineObject):
+class ModelVersionTag(_SubmarineObject):
"""
Tag object associated with a model version.
"""
@@ -26,5 +26,5 @@ class ModelTag(_SubmarineObject):
@property
def tag(self):
- """String tag"""
+ """String tag."""
return self._tag
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
index 7ff5e69..b88ac22 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model.py
@@ -18,7 +18,7 @@ from submarine.entities._submarine_object import _SubmarineObject
class RegisteredModel(_SubmarineObject):
"""
- Registered Model object.
+ Registered model object.
"""
def __init__(self, name, creation_time, last_updated_time, description=None, tags=None):
@@ -35,12 +35,12 @@ class RegisteredModel(_SubmarineObject):
@property
def creation_time(self):
- """Datetime object. Model version creation timestamp."""
+ """Datetime object. Registered model creation datetime."""
return self._creation_time
@property
def last_updated_time(self):
- """Datetime object. Timestamp of last update for this model version."""
+ """Datetime object. Datetime of last update for this model."""
return self._last_updated_time
@property
@@ -52,6 +52,3 @@ class RegisteredModel(_SubmarineObject):
def tags(self):
"""List of strings"""
return self._tags
-
- def _add_tag(self, tag):
- self._tags.append(tag)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model_tag.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model_tag.py
index 22de3c9..70d04ef 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model_tag.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/registered_model_tag.py
@@ -26,5 +26,5 @@ class RegisteredModelTag(_SubmarineObject):
@property
def tag(self):
- """String tag."""
+ """String tag"""
return self._tag
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/models.py b/submarine-sdk/pysubmarine/submarine/store/database/models.py
index a1caffe..ff55c2a 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/models.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/models.py
@@ -14,7 +14,7 @@
# limitations under the License.
from datetime import datetime
-from typing import Any
+from typing import List
import sqlalchemy as sa
from sqlalchemy import (
@@ -30,56 +30,71 @@ from sqlalchemy import (
)
from sqlalchemy.dialects.mysql import DATETIME
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import Mapped, relationship
from submarine.entities import Experiment, Metric, Param
from submarine.entities.model_registry import (
- ModelTag,
ModelVersion,
+ ModelVersionTag,
RegisteredModel,
RegisteredModelTag,
)
-from submarine.entities.model_registry.model_version_stages import STAGE_NONE
+from submarine.entities.model_registry.model_stages import STAGE_NONE
# Base class in sqlalchemy is a dynamic type
-Base: Any = declarative_base()
+Base = declarative_base()
-# +---------------------+-------------------------+-------------------------+-------------+
-# | name | creation_time | last_updated_time | description |
-# +---------------------+-------------------------+-------------------------+-------------+
-# | image_classfication | 2021-08-31 11:11:11.111 | 2021-09-02 11:11:11.111 | ... |
-# | speech_recoginition | 2021-08-31 16:16:16.166 | 2021-08-31 20:20:20.200 | ... |
-# +---------------------+-------------------------+-------------------------+-------------+
+# +----------+-------------------------+-------------------------+-------------+
+# | name | creation_time | last_updated_time | description |
+# +----------+-------------------------+-------------------------+-------------+
+# | ResNet50 | 2021-08-31 11:11:11.111 | 2021-09-02 11:11:11.111 | ... |
+# | BERT | 2021-08-31 16:16:16.166 | 2021-08-31 20:20:20.200 | ... |
+# +----------+-------------------------+-------------------------+-------------+
class SqlRegisteredModel(Base):
__tablename__ = "registered_model"
- name = Column(String(256), unique=True, nullable=False)
+ name = Column(String(256), unique=True)
"""
- Name for registered models: Part of *Primary Key* for ``registered_model`` table.
+ Name of registered model: Part of *Primary Key* for ``registered_model`` table.
"""
creation_time = Column(DATETIME(fsp=3), default=datetime.now())
"""
- Creation time of registered models: default current time in milliseconds
+ Creation time of registered model: default current time in milliseconds.
"""
- last_updated_time = Column(DATETIME(fsp=3), nullable=True, default=None)
+ last_updated_time = Column(DATETIME(fsp=3), nullable=True)
"""
- Last updated time of registered model
+ Last updated time of registered model.
"""
- description = Column(String(5000), nullable=True, default="")
+ description = Column(String(5000), nullable=True, default=None)
"""
- Description for registered model
+ Description for registered model.
"""
- __table_args__ = (PrimaryKeyConstraint("name", name="registered_model_pk"),)
+ tags: Mapped[List["SqlRegisteredModelTag"]] = relationship(
+ "SqlRegisteredModelTag", back_populates="registered_model", cascade="all"
+ )
+ """
+ Registered model Tags reference to SqlRegisteredModelTag.
+ """
+
+ model_versions: Mapped[List["SqlModelVersion"]] = relationship(
+ "SqlModelVersion", back_populates="registered_model", cascade="all"
+ )
+ """
+ Model versions reference to SqlModelVersion.
+ """
+
+ __table_args__ = (PrimaryKeyConstraint("name", name="model_pk"),)
def __repr__(self):
- return "<SqlRegisteredModel ({}, {}, {}, {})>".format(
- self.name, self.creation_time, self.last_updated_time, self.description
+ return (
+ f"<SqlRegisteredModel ({self.name}, {self.creation_time}, {self.last_updated_time},"
+ f" {self.description})>"
)
def to_submarine_entity(self):
@@ -92,17 +107,17 @@ class SqlRegisteredModel(Base):
creation_time=self.creation_time,
last_updated_time=self.last_updated_time,
description=self.description,
- tags=[tag.to_submarine_entity for tag in self.registered_model_tag],
+ tags=[tag.to_submarine_entity() for tag in self.tags],
)
-# +---------------------+-------+
-# | name | tag |
-# +---------------------+-------+
-# | image_classfication | image |
-# | image_classfication | major |
-# | speech_recoginition | audio |
-# +---------------------+-------+
+# +----------+-----------+
+# | name | tag |
+# +----------+-----------+
+# | ResNet50 | image |
+# | ResNet50 | marketing |
+# | BERT | text |
+# +----------+-----------+
class SqlRegisteredModelTag(Base):
@@ -112,25 +127,23 @@ class SqlRegisteredModelTag(Base):
String(256), ForeignKey("registered_model.name", onupdate="cascade", ondelete="cascade")
)
"""
- Name for registered models: Part of *Primary Key* for ``registered_model_tag`` table. Refer to
- name of ``registered_model`` table.
+ Name of registered model: Part of *Primary Key* for ``registered_model_tag`` table.
+ Refer to name of ``registered_model`` table.
"""
tag = Column(String(256), nullable=False)
"""
- Registered model tag: `String` (limit 256 characters). Part of *Primary Key* for
- ``registered_model_tag`` table.
+ Registered model tag: `String` (limit 256 characters).
+ Part of *Primary Key* for ``registered_model_tag`` table.
"""
# linked entities
- registered_model = relationship(
- "SqlRegisteredModel", backref=backref("registered_model_tag", cascade="all")
- )
+ registered_model: SqlRegisteredModel = relationship("SqlRegisteredModel", back_populates="tags")
__table_args__ = (PrimaryKeyConstraint("name", "tag", name="registered_model_tag_pk"),)
def __repr__(self):
- return "<SqlRegisteredModelTag ({}, {})>".format(self.name, self.tag)
+ return f"<SqlRegisteredModelTag ({self.name}, {self.tag})>"
# entity mappers
def to_submarine_entity(self):
@@ -141,150 +154,153 @@ class SqlRegisteredModelTag(Base):
return RegisteredModelTag(self.tag)
-# +---------------------+---------+-----+-------------------------------+-----+
-# | name | version | ... | source | ... |
-# +---------------------+---------+-----+-------------------------------+-----+
-# | image_classfication | 1 | ... | s3://submarine/ResNet50/1/ | ... |
-# | image_classfication | 2 | ... | s3://submarine/DenseNet121/2/ | ... |
-# | speech_recoginition | 1 | ... | s3://submarine/ASR/1/ | ... |
-# +---------------------+---------+-----+-------------------------------+-----+
+# +----------+---------+-------------------------------+-----+
+# | name | version | source | ... |
+# +----------+---------+-------------------------------+-----+
+# | ResNet50 | 1 | s3://submarine/ResNet50/1/ | ... |
+# | ResNet50 | 2 | s3://submarine/ResNet50/2/ | ... |
+# | BERT | 1 | s3://submarine/BERT/1/ | ... |
+# +----------+---------+-------------------------------+-----+
class SqlModelVersion(Base):
__tablename__ = "model_version"
name = Column(
- String(256), ForeignKey("registered_model.name", onupdate="cascade", ondelete="cascade")
+ String(256),
+ ForeignKey("registered_model.name", onupdate="cascade", ondelete="cascade"),
+ nullable=False,
)
"""
- Name for registered models: Part of *Primary Key* for ``registered_model_tag`` table. Refer to
- name of ``registered_model`` table.
+ Name of model version: Part of *Primary Key* for ``model_version`` table.
"""
version = Column(Integer, nullable=False)
"""
- Model version: Part of *Primary Key* for ``registered_model_tag`` table.
+ Version of registered model: Part of *Primary Key* for ``model_version`` table.
+ """
+
+ source = Column(String(512), nullable=False, unique=True)
+ """
+ Source of model: Part of *Primary Key* for ``model_version`` table.
+ database link refer to this version of model.
"""
user_id = Column(String(64), nullable=False)
"""
- ID to whom this model is created
+ ID to whom this model is created.
"""
experiment_id = Column(String(64), nullable=False)
"""
- ID to which this model belongs to
+ ID to which this version of model belongs to.
"""
- current_stage = Column(String(20), default=STAGE_NONE)
+ current_stage = Column(String(64), default=STAGE_NONE)
"""
- Current stage of this model: it can be `None`, `Staging`, `Production` and `Achieved`
+ Current stage of this version of model: it can be `None`, `Developing`,
+ `Production` and `Archived`
"""
creation_time = Column(DATETIME(fsp=3), default=datetime.now())
"""
- Creation time of this model version: default current time in milliseconds
+ Creation time of this version of model: default current time in milliseconds
"""
- last_updated_time = Column(DATETIME(fsp=3), nullable=True, default=None)
+ last_updated_time = Column(DATETIME(fsp=3), nullable=True)
"""
- Last updated time of this model version
+ Last updated time of this version of model.
"""
- source = Column(String(512), nullable=True, default=None)
+ dataset = Column(String(256), nullable=True, default=None)
"""
- Source of model: database link refer to this model
+ Dataset used for this version of model.
"""
- dataset = Column(String(256), nullable=True, default=None)
+ description = Column(String(5000), nullable=True)
"""
- Dataset used for this model.
+ Description for this version of model.
"""
- description = Column(String(5000), nullable=True)
+ tags: Mapped[List["SqlModelVersionTag"]] = relationship(
+ "SqlModelVersionTag", back_populates="model_version", cascade="all"
+ )
"""
- Description for model version.
+ Model version tags reference to SqlModelVersionTag.
"""
# linked entities
- registered_model = relationship(
- "SqlRegisteredModel", backref=backref("model_version", cascade="all")
+ registered_model: SqlRegisteredModel = relationship(
+ "SqlRegisteredModel", back_populates="model_versions"
)
- __table_args__ = (PrimaryKeyConstraint("name", "version", name="model_version_pk"),)
+ __table_args__ = (PrimaryKeyConstraint("name", "version", "source", name="model_version_pk"),)
def __repr__(self):
- return "<SqlModelVersion ({}, {}, {}, {}, {}, {}, {}, {}, {}, {})>".format(
- self.name,
- self.version,
- self.user_id,
- self.experiment_id,
- self.current_stage,
- self.creation_time,
- self.last_updated_time,
- self.source,
- self.dataset,
- self.description,
+ return (
+ f"<SqlModelMetadata ({self.name}, {self.version}, {self.source}, {self.user_id},"
+ f" {self.experiment_id}, {self.current_stage}, {self.creation_time},"
+ f" {self.last_updated_time}, {self.dataset}, {self.description})>"
)
def to_submarine_entity(self):
"""
Convert DB model to corresponding Submarine entity.
- :return: :py:class:`submarine.entities.RegisteredModel`.
+ :return: :py:class:`submarine.entities.model_registry.ModelVersion`.
"""
return ModelVersion(
name=self.name,
version=self.version,
+ source=self.source,
user_id=self.user_id,
experiment_id=self.experiment_id,
current_stage=self.current_stage,
creation_time=self.creation_time,
last_updated_time=self.last_updated_time,
- source=self.source,
dataset=self.dataset,
description=self.description,
- tags=[tag.to_submarine_entity for tag in self.model_tag],
+ tags=[tag.to_submarine_entity() for tag in self.tags],
)
-# +---------------------+---------+-----------------+
-# | name | version | tag |
-# +---------------------+---------+-----------------+
-# | image_classfication | 1 | best |
-# | image_classfication | 1 | anomaly_support |
-# | image_classfication | 2 | testing |
-# | speech_recoginition | 1 | best |
-# +---------------------+---------+-----------------+
+# +----------+---------+----------+
+# | name | version | tag |
+# +----------+---------+----------+
+# | ResNet50 | 1 | best |
+# | ResNet50 | 1 | serving |
+# | ResNet50 | 2 | new |
+# | BERT | 1 | testing |
+# +----------+---------+----------+
-class SqlModelTag(Base):
- __tablename__ = "model_tag"
+class SqlModelVersionTag(Base):
+ __tablename__ = "model_version_tag"
name = Column(String(256), nullable=False)
"""
- Name for registered models: Part of *Foreign Key* for ``model_tag`` table. Refer to
- name of ``model_version`` table.
+ Name of registered model: Part of *Foreign Key* for ``model_version_tag`` table.
+ Refer to name of ``model_version`` table.
"""
version = Column(Integer, nullable=False)
"""
- version of model: Part of *Foreign Key* for ``model_tag`` table. Refer to
- version of ``model_version`` table.
+ version of model: Part of *Foreign Key* for ``model_version_tag`` table.
+ Refer to version of ``model_version`` table.
"""
tag = Column(String(256), nullable=False)
"""
- tag of model version: `String` (limit 256 characters). Part of *Primary Key* for
- ``model_tag`` table.
+ tag of model version: `String` (limit 256 characters).
+ Part of *Primary Key* for ``model_version_tag`` table.
"""
# linked entities
- model_version = relationship(
- "SqlModelVersion", foreign_keys=[name, version], backref=backref("model_tag", cascade="all")
+ model_version: SqlModelVersion = relationship(
+ "SqlModelVersion", foreign_keys=[name, version], back_populates="tags"
)
__table_args__ = (
- PrimaryKeyConstraint("name", "tag", name="model_tag_pk"),
+ PrimaryKeyConstraint("name", "version", "tag", name="model_version_tag_pk"),
ForeignKeyConstraint(
("name", "version"),
("model_version.name", "model_version.version"),
@@ -294,15 +310,15 @@ class SqlModelTag(Base):
)
def __repr__(self):
- return "<SqlRegisteredModelTag ({}, {}, {})>".format(self.name, self.version, self.tag)
+ return f"<SqlModelVersionTag ({self.name}, {self.version}, {self.tag})>"
# entity mappers
def to_submarine_entity(self):
"""
Convert DB model to corresponding submarine entity.
- :return: :py:class:`submarine.entities.ModelTag`.
+ :return: :py:class:`submarine.entities.ModelVersionTag`.
"""
- return ModelTag(self.tag)
+ return ModelVersionTag(self.tag)
# +--------------------+-----------------+-----------+-------------------------+-----+
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/__init__.py
similarity index 80%
rename from submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py
rename to submarine-sdk/pysubmarine/submarine/store/model_registry/__init__.py
index a8fae8e..a6eb1b5 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version_stages.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/__init__.py
@@ -12,10 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-STAGE_NONE = "None"
-STAGE_STAGING = "Staging"
-STAGE_PRODUCTION = "Production"
-STAGE_ARCHIVED = "Archived"
-
-ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
new file mode 100644
index 0000000..a5e81ce
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+
+
+class AbstractStore:
+ """
+ Abstract class for Backend model registry
+ This class defines the API interface for frontends to connect with various types of backends.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self) -> None:
+ """
+ Empty constructor for now. This is deliberately not marked as abstract, else every
+ derived class would be forced to create one.
+ """
+ pass
+
+ @abstractmethod
+ def create_registered_model(
+ self, name: str, description: str = None, tags: List[str] = None
+ ) -> RegisteredModel:
+ """
+ Create a new registered model in backend store.
+ :param name: Name of the new registered model.
+ This is expected to be unique in the backend store.
+ :param description: Description of the registered model.
+ :param tags: A list of tags associated with this registered model.
+ :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+ created in the backend.
+ """
+ pass
+
+ @abstractmethod
+ def update_registered_model_description(self, name: str, description: str) -> RegisteredModel:
+ """
+ Update description of the registered model.
+ :param name: Registered model name.
+ :param description: New description.
+ :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+ object.
+ """
+ pass
+
+ @abstractmethod
+ def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+ """
+ Rename the registered model.
+ :param name: Registered model name.
+ :param new_name: New proposed name.
+ :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+ object.
+ """
+ pass
+
+ @abstractmethod
+ def delete_registered_model(self, name: str) -> None:
+ """
+ Delete the registered model.
+ :param name: Registered model name.
+ :return: None.
+ """
+ pass
+
+ @abstractmethod
+ def list_registered_model(
+ self, filter_str: str = None, filter_tags: List[str] = None
+ ) -> List[RegisteredModel]:
+ """
+ List of all models.
+ :param filter_str: Filter query string, defaults to searching all registered models.
+ :param filter_tags: Filter tags, defaults not to filter any tags.
+ :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+ that satisfy the search expressions.
+ """
+ pass
+
+ @abstractmethod
+ def get_registered_model(self, name: str) -> RegisteredModel:
+ """
+ Get registered model instance by name.
+ :param name: Registered model name.
+ :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+ """
+ pass
+
+ @abstractmethod
+ def add_registered_model_tag(self, name: str, tag: str) -> None:
+ """
+ Add a tag for the registered model.
+ :param name: registered model name.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ pass
+
+ @abstractmethod
+ def delete_registered_model_tag(self, name: str, tag: str) -> None:
+ """
+ Delete a tag associated with the registered model.
+ :param name: Model name.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ pass
+
+ @abstractmethod
+ def create_model_version(
+ self,
+ name: str,
+ source: str,
+ user_id: str,
+ experiment_id: str,
+ dataset: str = None,
+ description: str = None,
+ tags: List[str] = None,
+ ) -> ModelVersion:
+ """
+ Create a new version of the registered model
+ :param name: Registered model name.
+ :param source: Source path where this version of model is stored.
+ :param user_id: User ID from server that created this model
+ :param experiment_id: ID of the experiment in which this model was created.
+ :param dataset: Dataset used by this version of the model.
+ :param description: Description of this version.
+ :param tags: A list of tag strings associated with this version of the model.
+ :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+ created in the backend.
+ """
+ pass
+
+ @abstractmethod
+ def update_model_version_description(
+ self, name: str, version: int, description: str
+ ) -> ModelVersion:
+ """
+ Update description associated with the version of model in backend.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :param description: New model description.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+ pass
+
+ @abstractmethod
+ def transition_model_version_stage(self, name: str, version: int, stage: str) -> ModelVersion:
+ """
+ Update this version's stage.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :param stage: New desired stage for this version of registered model.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+
+ @abstractmethod
+ def delete_model_version(self, name: str, version: int) -> None:
+ """
+ Delete model version in backend.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :return: None
+ """
+ pass
+
+ @abstractmethod
+ def get_model_version(self, name: str, version: int) -> ModelVersion:
+ """
+ Get the model by name and version.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+ pass
+
+ @abstractmethod
+ def list_model_versions(self, name: str, filter_tags: list = None) -> List[ModelVersion]:
+ """
+ List of all models that satisfy the filter criteria.
+ :param name: Registered model name.
+ :param filter_tags: Filter tags, defaults not to filter any tags.
+ :return: A List of :py:class:`submarine.entities.model_registry.ModelVersion` objects
+ that satisfy the search expressions.
+ """
+ pass
+
+ @abstractmethod
+ def get_model_version_uri(self, name: str, version: int) -> str:
+ """
+ Get the location in Model registry for this version.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :return: A single URI location.
+ """
+ pass
+
+ @abstractmethod
+ def add_model_version_tag(self, name: str, version: int, tag: str) -> None:
+ """
+ Add a tag for this version of model.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ pass
+
+ @abstractmethod
+ def delete_model_version_tag(self, name: str, version: int, tag: str) -> None:
+ """
+ Delete a tag associated with this version of model.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ pass
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
new file mode 100644
index 0000000..47fb075
--- /dev/null
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
@@ -0,0 +1,578 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_stages import (
+ STAGE_DELETED_INTERNAL,
+ get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+ Base,
+ SqlModelVersion,
+ SqlModelVersionTag,
+ SqlRegisteredModel,
+ SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+ validate_description,
+ validate_model_name,
+ validate_model_version,
+ validate_tag,
+ validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+ def __init__(self, db_uri: str) -> None:
+ """
+ Create a database backed store.
+ :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+ the `SQLAlchemy docs
+ <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+ for format specifications. Submarine supports the ``mysql`` dialect.
+ """
+ super(SqlAlchemyStore, self).__init__()
+
+ self.db_uri = db_uri
+ self.db_type = extract_db_type_from_uri(db_uri)
+ self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+ insp = sqlalchemy.inspect(self.engine)
+
+ # Verify that all model registry tables exist.
+ expected_tables = {
+ SqlRegisteredModel.__tablename__,
+ SqlRegisteredModelTag.__tablename__,
+ SqlModelVersion.__tablename__,
+ SqlModelVersionTag.__tablename__,
+ }
+ if len(expected_tables & set(insp.get_table_names())) == 0:
+ SqlAlchemyStore._initialize_tables(self.engine)
+ Base.metadata.bind = self.engine
+ SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+ self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+ @staticmethod
+ def _initialize_tables(engine: Engine):
+ _logger.info("Creating initial Submarine database tables...")
+ Base.metadata.create_all(engine)
+
+ @staticmethod
+ def _get_managed_session_maker(SessionMaker: sessionmaker):
+ """
+ Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+ using a context manager. Any session produced by this factory is automatically committed
+ if no exceptions are encountered within its associated context. If an exception is
+ encountered, the session is rolled back. Finally, any session produced by this factory is
+ automatically closed when the session's associated context is exited.
+ """
+
+ @contextmanager
+ def make_managed_session():
+ """Provide a transactional scope around a series of operations."""
+ session: Session = SessionMaker()
+ try:
+ yield session
+ session.commit()
+ except SubmarineException:
+ session.rollback()
+ raise
+ except Exception as e:
+ session.rollback()
+ raise SubmarineException(e)
+ finally:
+ session.close()
+
+ return make_managed_session
+
+ @staticmethod
+ def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+ """
+ :return: A list of SQLAlchemy query options that can be used to eagerly
+ load the following registered model attributes
+ when fetching a model: ``registered_model_tag``.
+ """
+ return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.tags)]
+
+ @staticmethod
+ def _get_eager_model_version_query_options():
+ """
+ :return: A list of SQLAlchemy query options that can be used to eagerly
+ load the following model version attributes
+ when fetching a model: ``model_version_tag``.
+ """
+ return [sqlalchemy.orm.subqueryload(SqlModelVersion.tags)]
+
+ def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+ """
+ Store in db
+ """
+ if type(objs) is list:
+ session.add_all(objs)
+ else:
+ # single object
+ session.add(objs)
+
+ def create_registered_model(
+ self, name: str, description: str = None, tags: List[str] = None
+ ) -> RegisteredModel:
+ """
+ Create a new registered model in backend store.
+ :param name: Name of the new registered model.
+ This is expected to be unique in the backend store.
+ :param description: Description of the registered model.
+ :param tags: A list of tags associated with this registered model.
+ :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+ created in the backend.
+ """
+ validate_model_name(name)
+ validate_tags(tags)
+ validate_description(description)
+ with self.ManagedSessionMaker() as session:
+ try:
+ creation_time = datetime.now()
+ registered_model = SqlRegisteredModel(
+ name=name,
+ creation_time=creation_time,
+ last_updated_time=creation_time,
+ description=description,
+ tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+ )
+ self._save_to_db(session, registered_model)
+ session.flush()
+ return registered_model.to_submarine_entity()
+ except sqlalchemy.exc.IntegrityError as e:
+ raise SubmarineException(
+ f"Registered model (name={name}) already exists.\nError: {str(e)}"
+ )
+
+ @classmethod
+ def _get_sql_registered_model(
+ cls, session: Session, name: str, eager: bool = False
+ ) -> SqlRegisteredModel:
+ """
+ :param eager: If ``True``, eagerly loads the registered model's tags.
+ If ``False``, these attributes are not eagerly loaded and
+ will be loaded when their corresponding object properties
+ are accessed from the resulting ``SqlRegisteredModel`` object.
+ """
+ validate_model_name(name)
+ query_options = cls._get_eager_registered_model_query_options() if eager else []
+ models: List[SqlRegisteredModel] = (
+ session.query(SqlRegisteredModel)
+ .options(*query_options)
+ .filter(SqlRegisteredModel.name == name)
+ .all()
+ )
+
+ if len(models) == 0:
+ raise SubmarineException(f"Registered model with name={name} not found")
+ elif len(models) > 1:
+ raise SubmarineException(
+ f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+ )
+ else:
+ return models[0]
+
+ def update_registered_model_description(self, name: str, description: str) -> RegisteredModel:
+ """
+ Update description of the registered model.
+ :param name: Registered model name.
+ :param description: New description.
+ :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+ object.
+ """
+ validate_description(description)
+ with self.ManagedSessionMaker() as session:
+ sql_registered_model = self._get_sql_registered_model(session, name)
+ sql_registered_model.description = description
+ sql_registered_model.last_updated_time = datetime.now()
+ self._save_to_db(session, sql_registered_model)
+ session.flush()
+ return sql_registered_model.to_submarine_entity()
+
+ def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+ """
+ Rename the registered model.
+ :param name: Registered model name.
+ :param new_name: New proposed name.
+ :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+ object.
+ """
+ validate_model_name(new_name)
+ with self.ManagedSessionMaker() as session:
+ sql_registered_model = self._get_sql_registered_model(session, name)
+ try:
+ update_time = datetime.now()
+ sql_registered_model.name = new_name
+ sql_registered_model.last_updated_time = update_time
+ for sql_model_version in sql_registered_model.model_versions:
+ sql_model_version.name = new_name
+ sql_model_version.last_updated_time = update_time
+ self._save_to_db(
+ session, [sql_registered_model] + sql_registered_model.model_versions
+ )
+ session.flush()
+ return sql_registered_model.to_submarine_entity()
+ except sqlalchemy.exc.IntegrityError as e:
+ raise SubmarineException(
+ f"Registered Model (name={name}) already exists. Error: {str(e)}"
+ )
+
+ def delete_registered_model(self, name: str) -> None:
+ """
+ Delete the registered model.
+ :param name: Registered model name.
+ :return: None.
+ """
+ with self.ManagedSessionMaker() as session:
+ sql_registered_model = self._get_sql_registered_model(session, name)
+ session.delete(sql_registered_model)
+
+ def list_registered_model(
+ self, filter_str: str = None, filter_tags: List[str] = None
+ ) -> List[RegisteredModel]:
+ """
+ List of all models.
+ :param filter_str: Name prefix to match, defaults to searching all registered models.
+ :param filter_tags: Filter tags, defaults not to filter any tags.
+ :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+ that satisfy the search expressions.
+ """
+ conditions = []
+ if filter_tags is not None:
+ conditions += [
+ SqlRegisteredModel.tags.any(SqlRegisteredModelTag.tag.contains(tag))
+ for tag in filter_tags
+ ]
+ if filter_str is not None:
+ conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+ with self.ManagedSessionMaker() as session:
+ sql_registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+ return [
+ sql_registered_model.to_submarine_entity()
+ for sql_registered_model in sql_registered_models
+ ]
+
+ def get_registered_model(self, name: str) -> RegisteredModel:
+ """
+ Get registered model instance by name.
+ :param name: Registered model name.
+ :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+ """
+ with self.ManagedSessionMaker() as session:
+ return self._get_sql_registered_model(session, name, True).to_submarine_entity()
+
+ @classmethod
+ def _get_registered_model_tag(
+ cls, session: Session, name: str, tag: str
+ ) -> SqlRegisteredModelTag:
+ tags = (
+ session.query(SqlRegisteredModelTag)
+ .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+ .all()
+ )
+ if len(tags) == 0:
+ raise SubmarineException(f"Registered model tag with name={name}, tag={tag} not found")
+ elif len(tags) > 1:
+ raise SubmarineException(
+ f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+ f" {len(tags)}."
+ )
+ else:
+ return tags[0]
+
+ def add_registered_model_tag(self, name: str, tag: str) -> None:
+ """
+ Add a tag for the registered model.
+ :param name: registered model name.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ validate_model_name(name)
+ validate_tag(tag)
+ with self.ManagedSessionMaker() as session:
+ # check if registered model exists
+ self._get_sql_registered_model(session, name)
+ session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+ def delete_registered_model_tag(self, name: str, tag: str) -> None:
+ """
+ Delete a tag associated with the registered model.
+ :param name: Model name.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ validate_model_name(name)
+ validate_tag(tag)
+ with self.ManagedSessionMaker() as session:
+ # check if registered model exists
+ self._get_sql_registered_model(session, name)
+ existing_tag = self._get_registered_model_tag(session, name, tag)
+ session.delete(existing_tag)
+
+ def create_model_version(
+ self,
+ name: str,
+ source: str,
+ user_id: str,
+ experiment_id: str,
+ dataset: str = None,
+ description: str = None,
+ tags: List[str] = None,
+ ) -> ModelVersion:
+ """
+ Create a new version of the registered model
+ :param name: Registered model name.
+ :param source: Source path where this version of model is stored.
+ :param user_id: User ID from server that created this model
+ :param experiment_id: ID of the experiment in which this model was created.
+ :param dataset: Dataset used by this version of the model.
+ :param description: Description of this version.
+ :param tags: A list of tag strings associated with this version of the model.
+ :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+ created in the backend.
+ """
+
+ def next_version(sql_registered_model: SqlRegisteredModel) -> int:
+ if sql_registered_model.model_versions:
+ return max([m.version for m in sql_registered_model.model_versions]) + 1
+ else:
+ return 1
+
+ validate_model_name(name)
+ validate_description(description)
+ validate_tags(tags)
+ with self.ManagedSessionMaker() as session:
+ try:
+ creation_time = datetime.now()
+ sql_registered_model = self._get_sql_registered_model(session, name)
+ sql_registered_model.last_updated_time = creation_time
+ model_metadata = SqlModelVersion(
+ name=name,
+ version=next_version(sql_registered_model),
+ source=source,
+ user_id=user_id,
+ experiment_id=experiment_id,
+ creation_time=creation_time,
+ last_updated_time=creation_time,
+ dataset=dataset,
+ description=description,
+ tags=[SqlModelVersionTag(tag=tag) for tag in tags or []],
+ )
+ self._save_to_db(session, [sql_registered_model, model_metadata])
+ session.flush()
+ return model_metadata.to_submarine_entity()
+ except sqlalchemy.exc.IntegrityError:
+ raise SubmarineException(f"Model create error (name={name}).")
+
+ @classmethod
+ def _get_sql_model_version(
+ cls, session: Session, name: str, version: int, eager: bool = False
+ ) -> SqlModelVersion:
+ """
+ :param eager: If ``True``, eagerly loads the model's tags.
+ If ``False``, these attributes are not eagerly loaded and
+ will be loaded when their corresponding object properties
+ are accessed from the resulting ``SqlModelVersion`` object.
+ """
+ validate_model_name(name)
+ validate_model_version(version)
+ query_options = cls._get_eager_model_version_query_options() if eager else []
+ conditions = [
+ SqlModelVersion.name == name,
+ SqlModelVersion.version == version,
+ SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
+ ]
+
+ models: List[SqlModelVersion] = (
+ session.query(SqlModelVersion).options(*query_options).filter(*conditions).all()
+ )
+ if len(models) == 0:
+ raise SubmarineException(f"Model Version (name={name}, version={version}) not found.")
+ elif len(models) > 1:
+ raise SubmarineException(
+ f"Expected only 1 model version with (name={name}, version={version}). Found"
+ f" {len(models)}."
+ )
+ else:
+ return models[0]
+
+ def update_model_version_description(
+ self, name: str, version: int, description: str
+ ) -> ModelVersion:
+ """
+ Update description associated with the version of model in backend.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :param description: New model description.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+ validate_description(description)
+ with self.ManagedSessionMaker() as session:
+ update_time = datetime.now()
+ sql_model = self._get_sql_model_version(session, name, version)
+ sql_model.description = description
+ sql_model.last_updated_time = update_time
+ self._save_to_db(session, sql_model)
+ return sql_model.to_submarine_entity()
+
+ def transition_model_version_stage(self, name: str, version: int, stage: str) -> ModelVersion:
+ """
+ Update this version's stage.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :param stage: New desired stage for this version of registered model.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+ with self.ManagedSessionMaker() as session:
+ last_updated_time = datetime.now()
+
+ sql_model_version = self._get_sql_model_version(session, name, version)
+ sql_model_version.current_stage = get_canonical_stage(stage)
+ sql_model_version.last_updated_time = last_updated_time
+ sql_registered_model = sql_model_version.registered_model
+ sql_registered_model.last_updated_time = last_updated_time
+ self._save_to_db(session, [sql_model_version, sql_registered_model])
+ return sql_model_version.to_submarine_entity()
+
+ def delete_model_version(self, name: str, version: int) -> None:
+ """
+ Delete model version in backend.
+ :param name: Registered model name.
+ :param version: Version of the registered model.
+ :return: None
+ """
+ with self.ManagedSessionMaker() as session:
+ updated_time = datetime.now()
+ sql_model_version = self._get_sql_model_version(session, name, version)
+ sql_registered_model = sql_model_version.registered_model
+ sql_registered_model.last_updated_time = updated_time
+ session.delete(sql_model_version)
+ self._save_to_db(session, sql_registered_model)
+ session.flush()
+
+ def get_model_version(self, name: str, version: int) -> ModelVersion:
+ """
+ Get the model by name and version.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+ """
+ with self.ManagedSessionMaker() as session:
+ sql_model_version = self._get_sql_model_version(session, name, version, True)
+ return sql_model_version.to_submarine_entity()
+
+ def list_model_versions(self, name: str, filter_tags: list = None) -> List[ModelVersion]:
+ """
+ List of all models that satisfy the filter criteria.
+ :param name: Registered model name.
+ :param filter_tags: Filter tags, defaults not to filter any tags.
+ :return: A List of :py:class:`submarine.entities.model_registry.ModelVersion` objects
+ that satisfy the search expressions.
+ """
+ conditions = [SqlModelVersion.name == name]
+ if filter_tags is not None:
+ conditions += [
+ SqlModelVersion.tags.any(SqlModelVersionTag.tag.contains(tag))
+ for tag in filter_tags
+ ]
+ with self.ManagedSessionMaker() as session:
+ sql_models = session.query(SqlModelVersion).filter(*conditions).all()
+ return [sql_model.to_submarine_entity() for sql_model in sql_models]
+
+ def get_model_version_uri(self, name: str, version: int) -> str:
+ """
+ Get the location in Model registry for this version.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :return: A single URI location.
+ """
+ with self.ManagedSessionMaker() as session:
+ sql_model = self._get_sql_model_version(session, name, version)
+ return sql_model.to_submarine_entity().source
+
+ @classmethod
+ def _get_sql_model_version_tag(
+ cls, session: Session, name: str, version: int, tag: str
+ ) -> SqlModelVersionTag:
+ tags = (
+ session.query(SqlModelVersionTag)
+ .filter(
+ SqlModelVersionTag.name == name,
+ SqlModelVersionTag.name == name,
+ SqlModelVersionTag.version == version,
+ SqlModelVersionTag.tag == tag,
+ )
+ .all()
+ )
+ if len(tags) == 0:
+ raise SubmarineException(
+ f"Model version tag with name={name}, version={version}, tag={tag} not found"
+ )
+ elif len(tags) > 1:
+ raise SubmarineException(
+ f"Expected only 1 model version tag with name={name}, version={version}, tag={tag}."
+ f" Found {len(tags)}."
+ )
+ else:
+ return tags[0]
+
+ def add_model_version_tag(self, name: str, version: int, tag: str) -> None:
+ """
+ Add a tag for this version of model.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ validate_model_name(name)
+ validate_model_version(version)
+ validate_tag(tag)
+ with self.ManagedSessionMaker() as session:
+ # check if model version exists
+ self._get_sql_model_version(session, name, version)
+ session.merge(SqlModelVersionTag(name=name, version=version, tag=tag))
+
+ def delete_model_version_tag(self, name: str, version: int, tag: str) -> None:
+ """
+ Delete a tag associated with this version of model.
+ :param name: Registered model name.
+ :param version: Version of registered model.
+ :param tag: String of tag value.
+ :return: None.
+ """
+ validate_model_name(name)
+ validate_model_version(version)
+ validate_tag(tag)
+ with self.ManagedSessionMaker() as session:
+ # check if model version exists
+ self._get_sql_model_version(session, name, version)
+ existing_tag = self._get_sql_model_version_tag(session, name, version, tag)
+ session.delete(existing_tag)
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/__init__.py b/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
index 24e3314..a316524 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/__init__.py
@@ -14,17 +14,10 @@
# limitations under the License.
from submarine.tracking.client import SubmarineClient
-from submarine.tracking.utils import (
- _JOB_ID_ENV_VAR,
- _TRACKING_URI_ENV_VAR,
- get_tracking_uri,
- set_tracking_uri,
-)
+from submarine.tracking.utils import _JOB_ID_ENV_VAR, _TRACKING_URI_ENV_VAR
__all__ = [
"SubmarineClient",
- "get_tracking_uri",
- "set_tracking_uri",
"_TRACKING_URI_ENV_VAR",
"_JOB_ID_ENV_VAR",
]
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index b342b24..2ee9b09 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -14,6 +14,7 @@
# limitations under the License.
import time
+import submarine
from submarine.entities import Metric, Param
from submarine.tracking import utils
from submarine.utils.validation import validate_metric, validate_param
@@ -24,15 +25,15 @@ class SubmarineClient(object):
Client of an submarine Tracking Server that creates and manages experiments and runs.
"""
- def __init__(self, tracking_uri=None):
+ def __init__(self, db_uri=None):
"""
- :param tracking_uri: Address of local or remote tracking server. If not provided, defaults
- to the service set by ``submarine.tracking.set_tracking_uri``. See
+ :param db_uri: Address of local or remote tracking server. If not provided, defaults
+ to the service set by ``submarine.utils.set_db_uri``. See
`Where Runs Get Recorded <../tracking.html#where-runs-get-recorded>`_
for more info.
"""
- self.tracking_uri = tracking_uri or utils.get_tracking_uri()
- self.store = utils.get_sqlalchemy_store(self.tracking_uri)
+ self.db_uri = db_uri or submarine.get_db_uri()
+ self.store = utils.get_sqlalchemy_store(self.db_uri)
def log_metric(self, job_id, key, value, worker_index, timestamp=None, step=None):
"""
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
index af169ec..ec0ec14 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
@@ -19,7 +19,6 @@ import json
import os
import uuid
-from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
from submarine.store.sqlalchemy_store import SqlAlchemyStore
from submarine.utils import env
@@ -42,39 +41,6 @@ _TRACKING_PASSWORD_ENV_VAR = "SUBMARINE_TRACKING_PASSWORD"
_TRACKING_TOKEN_ENV_VAR = "SUBMARINE_TRACKING_TOKEN"
_TRACKING_INSECURE_TLS_ENV_VAR = "SUBMARINE_TRACKING_INSECURE_TLS"
-_tracking_uri = None
-
-
-def is_tracking_uri_set():
- """Returns True if the tracking URI has been set, False otherwise."""
- if _tracking_uri or env.get_env(_TRACKING_URI_ENV_VAR):
- return True
- return False
-
-
-def set_tracking_uri(uri):
- """
- Set the tracking server URI. This does not affect the
- currently active run (if one exists), but takes effect for successive runs.
- """
- global _tracking_uri
- _tracking_uri = uri
-
-
-def get_tracking_uri():
- """
- Get the current tracking URI. This may not correspond to the tracking URI of
- the currently active run, since the tracking URI can be updated via ``set_tracking_uri``.
- :return: The tracking URI.
- """
- global _tracking_uri
- if _tracking_uri is not None:
- return _tracking_uri
- elif env.get_env(_TRACKING_URI_ENV_VAR) is not None:
- return env.get_env(_TRACKING_URI_ENV_VAR)
- else:
- return DEFAULT_SUBMARINE_JDBC_URL
-
def get_job_id():
"""
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
index 6f2b95c..4908ba6 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/__init__.py
@@ -16,6 +16,7 @@
from six.moves import urllib
from submarine.exceptions import SubmarineException
+from submarine.utils.db_utils import get_db_uri, set_db_uri
def extract_db_type_from_uri(db_uri):
@@ -35,3 +36,9 @@ def extract_db_type_from_uri(db_uri):
raise SubmarineException(error_msg)
return db_type
+
+
+__all__ = [
+ "get_db_uri",
+ "set_db_uri",
+]
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
similarity index 50%
copy from submarine-sdk/pysubmarine/submarine/utils/__init__.py
copy to submarine-sdk/pysubmarine/submarine/utils/db_utils.py
index 6f2b95c..b23ce2d 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/db_utils.py
@@ -13,25 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import urllib
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import env
-from submarine.exceptions import SubmarineException
+_DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
-def extract_db_type_from_uri(db_uri):
+_db_uri = None
+
+
+def is_db_uri_set():
+ """Returns True if the DB URI has been set, False otherwise."""
+ if _db_uri or env.get_env(_DB_URI_ENV_VAR):
+ return True
+ return False
+
+
+def set_db_uri(uri):
"""
- Parse the specified DB URI to extract the database type. Confirm the database type is
- supported. If a driver is specified, confirm it passes a plausible regex.
+ Set the DB URI. This does not affect the currently active run (if one exists),
+ but takes effect for successive runs.
"""
- scheme = urllib.parse.urlparse(db_uri).scheme
- scheme_plus_count = scheme.count("+")
+ global _db_uri
+ _db_uri = uri
- if scheme_plus_count == 0:
- db_type = scheme
- elif scheme_plus_count == 1:
- db_type, _ = scheme.split("+")
- else:
- error_msg = "Invalid database URI: '%s'. %s" % (db_uri, "INVALID_DB_URI_MSG")
- raise SubmarineException(error_msg)
- return db_type
+def get_db_uri():
+ """
+ Get the current DB URI.
+ :return: The DB URI.
+ """
+ global _db_uri
+ if _db_uri is not None:
+ return _db_uri
+ elif env.get_env(_DB_URI_ENV_VAR) is not None:
+ return env.get_env(_DB_URI_ENV_VAR)
+ else:
+ return DEFAULT_SUBMARINE_JDBC_URL
diff --git a/submarine-sdk/pysubmarine/submarine/utils/validation.py b/submarine-sdk/pysubmarine/submarine/utils/validation.py
index 60dd9b0..00e2c98 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/validation.py
+++ b/submarine-sdk/pysubmarine/submarine/utils/validation.py
@@ -19,6 +19,7 @@ import numbers
import posixpath
import re
from datetime import datetime
+from typing import List, Optional
from submarine.exceptions import SubmarineException
from submarine.store.database.db_types import DATABASE_ENGINES
@@ -116,8 +117,46 @@ def validate_param(key, value):
_validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
+def validate_tags(tags: Optional[List[str]]) -> None:
+ if tags is not None and not isinstance(tags, list):
+ raise SubmarineException("parameter tags must be list or None.")
+ for tag in tags or []:
+ validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+ """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+ # Reuse param & metric check.
+ if tag is None or tag == "":
+ raise SubmarineException("Tag cannot be empty.")
+ if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+ raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))
+
+
+def validate_model_name(name: str) -> None:
+ if name is None or name == "":
+ raise SubmarineException("Model name cannot be empty.")
+
+
+def validate_model_version(version: int) -> None:
+ if not isinstance(version, int):
+ raise SubmarineException(f"Model version must be an integer, got {type(version)} type.")
+ elif version < 1:
+ raise SubmarineException(f"Model version must bigger than 0, but got {version}")
+
+
+def validate_description(description: Optional[str]) -> None:
+ if not isinstance(description, str) and description is not None:
+ raise SubmarineException(f"Description must be String or None, but got {type(description)}")
+ if isinstance(description, str) and len(description) > 5000:
+ raise SubmarineException(
+ f"Description must less than 5000 words, but got {len(description)}"
+ )
+
+
def _validate_db_type_string(db_type):
"""validates db_type parsed from DB URI is supported"""
if db_type not in DATABASE_ENGINES:
- error_msg = "Invalid database engine: '%s'. '%s'" % (db_type, _UNSUPPORTED_DB_TYPE_MSG)
- raise SubmarineException(error_msg)
+ raise SubmarineException(
+ f"Invalid database engine: '{db_type}'. '{_UNSUPPORTED_DB_TYPE_MSG}'"
+ )
diff --git a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
index 6611fb4..9cead79 100644
--- a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
+++ b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
@@ -15,21 +15,20 @@
from datetime import datetime
-from submarine.entities.model_registry.model_tag import ModelTag
-from submarine.entities.model_registry.model_version import ModelVersion
-from submarine.entities.model_registry.model_version_stages import STAGE_NONE
+from submarine.entities.model_registry import ModelVersion, ModelVersionTag
+from submarine.entities.model_registry.model_stages import STAGE_NONE
class TestModelVersion:
default_data = {
"name": "test",
"version": 1,
+ "source": "path/to/source",
"user_id": "admin",
"experiment_id": "experiment_1",
"current_stage": STAGE_NONE,
"creation_time": datetime.now(),
"last_updated_time": datetime.now(),
- "source": "path/to/source",
"dataset": "test",
"description": "registered model description",
"tags": [],
@@ -37,42 +36,42 @@ class TestModelVersion:
def _check(
self,
- model_version,
+ model_metadata,
name,
version,
+ source,
user_id,
experiment_id,
current_stage,
creation_time,
last_updated_time,
- source,
dataset,
description,
tags,
):
- isinstance(model_version, ModelVersion)
- assert model_version.name == name
- assert model_version.version == version
- assert model_version.user_id == user_id
- assert model_version.experiment_id == experiment_id
- assert model_version.current_stage == current_stage
- assert model_version.creation_time == creation_time
- assert model_version.last_updated_time == last_updated_time
- assert model_version.source == source
- assert model_version.dataset == dataset
- assert model_version.description == description
- assert model_version.tags == tags
+ isinstance(model_metadata, ModelVersion)
+ assert model_metadata.name == name
+ assert model_metadata.version == version
+ assert model_metadata.source == source
+ assert model_metadata.user_id == user_id
+ assert model_metadata.experiment_id == experiment_id
+ assert model_metadata.current_stage == current_stage
+ assert model_metadata.creation_time == creation_time
+ assert model_metadata.last_updated_time == last_updated_time
+ assert model_metadata.dataset == dataset
+ assert model_metadata.description == description
+ assert model_metadata.tags == tags
def test_creation_and_hydration(self):
mv = ModelVersion(
self.default_data["name"],
self.default_data["version"],
+ self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
- self.default_data["source"],
self.default_data["dataset"],
self.default_data["description"],
self.default_data["tags"],
@@ -81,30 +80,30 @@ class TestModelVersion:
mv,
self.default_data["name"],
self.default_data["version"],
+ self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
- self.default_data["source"],
self.default_data["dataset"],
self.default_data["description"],
self.default_data["tags"],
)
def test_with_tags(self):
- tag1 = ModelTag("tag1")
- tag2 = ModelTag("tag2")
+ tag1 = ModelVersionTag("tag1")
+ tag2 = ModelVersionTag("tag2")
tags = [tag1, tag2]
mv = ModelVersion(
self.default_data["name"],
self.default_data["version"],
+ self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
- self.default_data["source"],
self.default_data["dataset"],
self.default_data["description"],
tags,
@@ -113,12 +112,12 @@ class TestModelVersion:
mv,
self.default_data["name"],
self.default_data["version"],
+ self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
- self.default_data["source"],
self.default_data["dataset"],
self.default_data["description"],
[t.tag for t in tags],
diff --git a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_registered_model.py b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_registered_model.py
index e1fa2af..138eb66 100644
--- a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_registered_model.py
+++ b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_registered_model.py
@@ -15,8 +15,7 @@
from datetime import datetime
-from submarine.entities.model_registry.registered_model import RegisteredModel
-from submarine.entities.model_registry.registered_model_tag import RegisteredModelTag
+from submarine.entities.model_registry import RegisteredModel, RegisteredModelTag
class TestRegisteredModel:
diff --git a/dev-support/style-check/python/mypy-requirements.txt b/submarine-sdk/pysubmarine/tests/store/__init__.py
similarity index 87%
copy from dev-support/style-check/python/mypy-requirements.txt
copy to submarine-sdk/pysubmarine/tests/store/__init__.py
index 6a581e3..a6eb1b5 100644
--- a/dev-support/style-check/python/mypy-requirements.txt
+++ b/submarine-sdk/pysubmarine/tests/store/__init__.py
@@ -12,9 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-mypy==0.910
-types-requests==2.25.6
-types-certifi==2020.4.0
-types-six==1.16.1
-types-python-dateutil==2.8.0
diff --git a/dev-support/style-check/python/mypy-requirements.txt b/submarine-sdk/pysubmarine/tests/store/model_registry/__init__.py
similarity index 87%
copy from dev-support/style-check/python/mypy-requirements.txt
copy to submarine-sdk/pysubmarine/tests/store/model_registry/__init__.py
index 6a581e3..a6eb1b5 100644
--- a/dev-support/style-check/python/mypy-requirements.txt
+++ b/submarine-sdk/pysubmarine/tests/store/model_registry/__init__.py
@@ -12,9 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-mypy==0.910
-types-requests==2.25.6
-types-certifi==2020.4.0
-types-six==1.16.1
-types-python-dateutil==2.8.0
diff --git a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
new file mode 100644
index 0000000..6d57c2e
--- /dev/null
+++ b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
@@ -0,0 +1,739 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from typing import List
+
+import freezegun
+import pytest
+from freezegun import freeze_time
+
+import submarine
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_stages import (
+ STAGE_ARCHIVED,
+ STAGE_DEVELOPING,
+ STAGE_NONE,
+ STAGE_PRODUCTION,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database import models
+from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
+freezegun.configure(default_ignore_list=["threading", "tensorflow"])
+
+
+@pytest.mark.e2e
+class TestSqlAlchemyStore(unittest.TestCase):
+ def setUp(self):
+ submarine.set_db_uri(
+ "mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
+ )
+ self.db_uri = submarine.get_db_uri()
+ self.store = SqlAlchemyStore(self.db_uri)
+
+ def tearDown(self):
+ submarine.set_db_uri(None)
+ models.Base.metadata.drop_all(self.store.engine)
+
+ def test_create_registered_model(self):
+ name1 = "test_create_RM_1"
+ rm1 = self.store.create_registered_model(name1)
+ self.assertEqual(rm1.name, name1)
+ self.assertEqual(rm1.description, None)
+
+ # error in duplicate
+ with self.assertRaises(SubmarineException):
+ self.store.create_registered_model(name1)
+
+ # test create with tags
+ name2 = "test_create_RM_2"
+ tags = ["tag1", "tag2"]
+ rm2 = self.store.create_registered_model(name2, tags=tags)
+ rm2d = self.store.get_registered_model(name2)
+ self.assertEqual(rm2.name, name2)
+ self.assertEqual(rm2.tags, tags)
+ self.assertEqual(rm2d.name, name2)
+ self.assertEqual(rm2d.tags, tags)
+
+ # test create with description
+ name3 = "test_create_RM_3"
+ description = "A test description."
+ rm3 = self.store.create_registered_model(name3, description)
+ rm3d = self.store.get_registered_model(name3)
+ self.assertEqual(rm3.name, name3)
+ self.assertEqual(rm3.description, description)
+ self.assertEqual(rm3d.name, name3)
+ self.assertEqual(rm3d.description, description)
+
+ # invalid model name
+ with self.assertRaises(SubmarineException):
+ self.store.create_registered_model(None)
+ with self.assertRaises(SubmarineException):
+ self.store.create_registered_model("")
+
+ def test_update_registered_model_description(self):
+ name = "test_update_RM"
+ rm1 = self.store.create_registered_model(name)
+ rm1d = self.store.get_registered_model(name)
+ self.assertEqual(rm1.name, name)
+ self.assertEqual(rm1d.description, None)
+
+ # update description
+ fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+ with freeze_time(fake_datetime):
+ rm2 = self.store.update_registered_model_description(name, "New description.")
+ rm2d = self.store.get_registered_model(name)
+ self.assertEqual(rm2.name, name)
+ self.assertEqual(rm2.description, "New description.")
+ self.assertEqual(rm2d.name, name)
+ self.assertEqual(rm2d.description, "New description.")
+ self.assertEqual(rm2d.last_updated_time, fake_datetime)
+
+ def test_rename_registered_model(self):
+ name = "test_rename_RM"
+ new_name = "test_rename_RN_new"
+ rm = self.store.create_registered_model(name)
+ self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
+ self.store.create_model_version(name, "path/to/source2", "test", "application_1235")
+ mv1d = self.store.get_model_version(name, 1)
+ mv2d = self.store.get_model_version(name, 2)
+ self.assertEqual(rm.name, name)
+ self.assertEqual(mv1d.name, name)
+ self.assertEqual(mv2d.name, name)
+
+ # test renaming registered model also updates its models
+ self.store.rename_registered_model(name, new_name)
+ rm = self.store.get_registered_model(new_name)
+ mv1d = self.store.get_model_version(new_name, 1)
+ mv2d = self.store.get_model_version(new_name, 2)
+ self.assertEqual(rm.name, new_name)
+ self.assertEqual(mv1d.name, new_name)
+ self.assertEqual(mv2d.name, new_name)
+
+ # test that renaming from the original (no longer existing) name fails
+ with self.assertRaises(SubmarineException):
+ self.store.rename_registered_model(name, name)
+
+ # invalid name will fail
+ with self.assertRaises(SubmarineException):
+ self.store.rename_registered_model(name, None)
+ with self.assertRaises(SubmarineException):
+ self.store.rename_registered_model(name, "")
+
+ def test_delete_registered_model(self):
+ name1 = "test_delete_RM"
+ name2 = "test_delete_RM_2"
+ rm_tags = ["rm_tag1", "rm_tag2"]
+ rm1 = self.store.create_registered_model(name1, tags=rm_tags)
+ rm2 = self.store.create_registered_model(name2, tags=rm_tags)
+ mv_tags = ["mv_tag1", "mv_tag2"]
+ rm1mv1 = self.store.create_model_version(
+ rm1.name, "path/to/source1", "test", "application_1234", tags=mv_tags
+ )
+ rm2mv1 = self.store.create_model_version(
+ rm2.name, "path/to/source2", "test", "application_1234", tags=mv_tags
+ )
+
+ # check store
+ rm1d = self.store.get_registered_model(rm1.name)
+ self.assertEqual(rm1d.name, name1)
+ self.assertEqual(rm1d.tags, rm_tags)
+ rm1mv1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ self.assertEqual(rm1mv1d.name, name1)
+ self.assertEqual(rm1mv1d.tags, mv_tags)
+
+ # delete registered model
+ self.store.delete_registered_model(rm1.name)
+
+ # cannot get model
+ with self.assertRaises(SubmarineException):
+ self.store.get_registered_model(rm1.name)
+
+ # cannot delete it again
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model(rm1.name)
+
+ # registered model tags are cascade deleted with the registered model
+ for tag in rm_tags:
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(rm1.name, tag)
+
+ # models are cascade deleted with the registered model
+ with self.assertRaises(SubmarineException):
+ self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+
+ # model tags are cascade deleted with the registered model
+ for tag in rm_tags:
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, tag)
+
+ # Other registered models and model versions are not affected
+ rm2d = self.store.get_registered_model(rm2.name)
+ self.assertEqual(rm2d.name, rm2.name)
+ self.assertEqual(rm2d.tags, rm2.tags)
+ rm2mv1 = self.store.get_model_version(rm2mv1.name, rm2mv1.version)
+ self.assertEqual(rm2mv1.name, rm2mv1.name)
+ self.assertEqual(rm2mv1.tags, rm2mv1.tags)
+
+ def _compare_registered_model_names(
+ self, results: List[RegisteredModel], rms: List[RegisteredModel]
+ ):
+ result_names = set([result.name for result in results])
+ rm_names = set([rm.name for rm in rms])
+
+ self.assertEqual(result_names, rm_names)
+
+ def test_list_registered_model(self):
+ rms = [self.store.create_registered_model(f"test_list_RM_{i}") for i in range(10)]
+
+ results = self.store.list_registered_model()
+ self.assertEqual(len(results), 10)
+ self._compare_registered_model_names(results, rms)
+
+ def test_list_registered_model_filter_with_string(self):
+ rms = [
+ self.store.create_registered_model("A"),
+ self.store.create_registered_model("AB"),
+ self.store.create_registered_model("B"),
+ self.store.create_registered_model("ABA"),
+ self.store.create_registered_model("AAA"),
+ ]
+
+ results = self.store.list_registered_model(filter_str="A")
+ self.assertEqual(len(results), 4)
+ self._compare_registered_model_names(rms[:2] + rms[3:], results)
+
+ results = self.store.list_registered_model(filter_str="AB")
+ self.assertEqual(len(results), 2)
+ self._compare_registered_model_names([rms[1], rms[3]], results)
+
+ results = self.store.list_registered_model(filter_str="ABA")
+ self.assertEqual(len(results), 1)
+ self._compare_registered_model_names([rms[3]], results)
+
+ results = self.store.list_registered_model(filter_str="ABC")
+ self.assertEqual(len(results), 0)
+ self.assertEqual(results, [])
+
+ def test_list_registered_model_filter_with_tags(self):
+ tags = ["tag1", "tag2", "tag3"]
+ rms = [
+ self.store.create_registered_model("test1"),
+ self.store.create_registered_model("test2", tags=tags[0:1]),
+ self.store.create_registered_model("test3", tags=tags[1:2]),
+ self.store.create_registered_model("test4", tags=[tags[0], tags[2]]),
+ self.store.create_registered_model("test5", tags=tags),
+ ]
+
+ results = self.store.list_registered_model(filter_tags=tags[0:1])
+ self.assertEqual(len(results), 3)
+ self._compare_registered_model_names(results, [rms[1], rms[3], rms[4]])
+
+ results = self.store.list_registered_model(filter_tags=tags[0:2])
+ self.assertEqual(len(results), 1)
+ self._compare_registered_model_names(results, [rms[-1]])
+
+ # empty result
+ other_tag = ["tag4"]
+ results = self.store.list_registered_model(filter_tags=other_tag)
+ self.assertEqual(len(results), 0)
+ self.assertEqual(results, [])
+
+ # empty result
+ results = self.store.list_registered_model(filter_tags=tags + other_tag)
+ self.assertEqual(len(results), 0)
+ self.assertEqual(results, [])
+
+ def test_list_registered_model_filter_both(self):
+ tags = ["tag1", "tag2", "tag3"]
+ rms = [
+ self.store.create_registered_model("A"),
+ self.store.create_registered_model("AB", tags=[tags[0]]),
+ self.store.create_registered_model("B", tags=[tags[1]]),
+ self.store.create_registered_model("ABA", tags=[tags[0], tags[2]]),
+ self.store.create_registered_model("AAA", tags=tags),
+ ]
+
+ results = self.store.list_registered_model()
+ self.assertEqual(len(results), 5)
+ self._compare_registered_model_names(results, rms)
+
+ results = self.store.list_registered_model(filter_str="A", filter_tags=[tags[0]])
+ self.assertEqual(len(results), 3)
+ self._compare_registered_model_names(results, [rms[1], rms[3], rms[4]])
+
+ results = self.store.list_registered_model(filter_str="AB", filter_tags=[tags[0]])
+ self.assertEqual(len(results), 2)
+ self._compare_registered_model_names(results, [rms[1], rms[3]])
+
+ results = self.store.list_registered_model(filter_str="AAA", filter_tags=tags)
+ self.assertEqual(len(results), 1)
+ self._compare_registered_model_names(results, [rms[-1]])
+
+ @freeze_time("2021-11-11 11:11:11.111000")
+ def test_get_registered_model(self):
+ name = "test_get_RM"
+ tags = ["tag1", "tag2"]
+ fake_datetime = datetime.now()
+ rm = self.store.create_registered_model(name, tags=tags)
+ self.assertEqual(rm.name, name)
+
+ rmd = self.store.get_registered_model(name)
+ self.assertEqual(rmd.name, name)
+ self.assertEqual(rmd.creation_time, fake_datetime)
+ self.assertEqual(rmd.last_updated_time, fake_datetime)
+ self.assertEqual(rmd.description, None)
+ self.assertEqual(rmd.tags, tags)
+
+ def test_add_registered_model_tag(self):
+ name1 = "test_add_RM_tag"
+ name2 = "test_add_RM_tag_2"
+ tags = ["tag1", "tag2"]
+ rm1 = self.store.create_registered_model(name1, tags=tags)
+ rm2 = self.store.create_registered_model(name2, tags=tags)
+ new_tag = "new tag"
+ self.store.add_registered_model_tag(name1, new_tag)
+ rmd = self.store.get_registered_model(name1)
+ all_tags = [new_tag] + tags
+ self.assertEqual(rmd.tags, all_tags)
+
+ # test add the same tag
+ same_tag = "tag1"
+ self.store.add_registered_model_tag(name1, same_tag)
+ rm1d = self.store.get_registered_model(rm1.name)
+ self.assertEqual(rm1d.tags, all_tags)
+
+ # does not affect other models
+ rm2d = self.store.get_registered_model(rm2.name)
+ self.assertEqual(rm2d.tags, tags)
+
+ # cannot set invalid tag
+ with self.assertRaises(SubmarineException):
+ self.store.add_registered_model_tag(rm1.name, None)
+ with self.assertRaises(SubmarineException):
+ self.store.add_registered_model_tag(rm1.name, "")
+
+ # cannot use invalid model name
+ with self.assertRaises(SubmarineException):
+ self.store.add_registered_model_tag(None, new_tag)
+
+ # cannot set tag on deleted registered model
+ self.store.delete_registered_model(rm1.name)
+ with self.assertRaises(SubmarineException):
+ new_tag = "new tag2"
+ self.store.add_registered_model_tag(name1, new_tag)
+
+ def test_delete_registered_model_tag(self):
+ name1 = "test_delete_RM_tag"
+ name2 = "test_delete_RM_tag_2"
+ tags = ["tag1", "tag2"]
+ rm1 = self.store.create_registered_model(name1, tags=tags)
+ rm2 = self.store.create_registered_model(name2, tags=tags)
+ new_tag = "new tag"
+ self.store.add_registered_model_tag(rm1.name, new_tag)
+ self.store.delete_registered_model_tag(rm1.name, new_tag)
+ rm1d = self.store.get_registered_model(rm1.name)
+ self.assertEqual(rm1d.tags, tags)
+
+ # delete tag that is already deleted
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(rm1.name, new_tag)
+ rm1d = self.store.get_registered_model(rm1.name)
+ self.assertEqual(rm1d.tags, tags)
+
+ # does not affect other models
+ rm2d = self.store.get_registered_model(rm2.name)
+ self.assertEqual(rm2d.tags, tags)
+
+ # Cannot delete an invalid tag (None or empty string)
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(rm1.name, None)
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(rm1.name, "")
+
+ # Cannot use invalid model name
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(None, "tag1")
+
+ # Cannot delete tag on deleted (non-existed) registered model
+ self.store.delete_registered_model(name1)
+ with self.assertRaises(SubmarineException):
+ self.store.delete_registered_model_tag(name1, "tag1")
+
+ @freeze_time("2021-11-11 11:11:11.111000")
+ def test_create_model_version(self):
+ model_name = "test_create_MV"
+ self.store.create_registered_model(model_name)
+ fake_datetime = datetime.now()
+ mv1 = self.store.create_model_version(
+ model_name, "path/to/source1", "test", "application_1234"
+ )
+ self.assertEqual(mv1.name, model_name)
+ self.assertEqual(mv1.version, 1)
+ self.assertEqual(mv1.creation_time, fake_datetime)
+
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.name, model_name)
+ self.assertEqual(m1d.user_id, "test")
+ self.assertEqual(m1d.experiment_id, "application_1234")
+ self.assertEqual(m1d.current_stage, STAGE_NONE)
+ self.assertEqual(m1d.creation_time, fake_datetime)
+ self.assertEqual(m1d.last_updated_time, fake_datetime)
+ self.assertEqual(m1d.source, "path/to/source1")
+ self.assertEqual(m1d.dataset, None)
+
+ # a new model version for the same registered model auto-increments the version
+ m2 = self.store.create_model_version(
+ model_name, "path/to/source2", "test", "application_1234"
+ )
+ m2d = self.store.get_model_version(m2.name, m2.version)
+ self.assertEqual(m2.version, 2)
+ self.assertEqual(m2d.version, 2)
+
+ # create model with tags
+ tags = ["tag1", "tag2"]
+ m3 = self.store.create_model_version(
+ model_name, "path/to/source3", "test", "application_1234", tags=tags
+ )
+ m3d = self.store.get_model_version(m3.name, m3.version)
+ self.assertEqual(m3.version, 3)
+ self.assertEqual(m3.tags, tags)
+ self.assertEqual(m3d.version, 3)
+ self.assertEqual(m3d.tags, tags)
+
+ # create model with description
+ description = "A test description."
+ m4 = self.store.create_model_version(
+ model_name, "path/to/source4", "test", "application_1234", description=description
+ )
+ m4d = self.store.get_model_version(m4.name, m4.version)
+ self.assertEqual(m4.version, 4)
+ self.assertEqual(m4.description, description)
+ self.assertEqual(m4d.version, 4)
+ self.assertEqual(m4d.description, description)
+
+ def test_update_model_version_description(self):
+ name = "test_update_MV_description"
+ self.store.create_registered_model(name)
+ mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.name, name)
+ self.assertEqual(m1d.version, 1)
+ self.assertEqual(m1d.description, None)
+
+ # update description
+ fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+ with freeze_time(fake_datetime):
+ self.store.update_model_version_description(mv1.name, mv1.version, "New description.")
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.name, name)
+ self.assertEqual(m1d.version, 1)
+ self.assertEqual(m1d.description, "New description.")
+ self.assertEqual(m1d.last_updated_time, fake_datetime)
+
+ def test_transition_model_version_stage(self):
+ name = "test_transition_MV_stage"
+ self.store.create_registered_model(name)
+ mv1 = self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
+ m2 = self.store.create_model_version(name, "path/to/source2", "test", "application_1234")
+
+ fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+ with freeze_time(fake_datetime):
+ self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_DEVELOPING)
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.current_stage, STAGE_DEVELOPING)
+
+ # check last updated time
+ self.assertEqual(m1d.last_updated_time, fake_datetime)
+ rmd = self.store.get_registered_model(name)
+ self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+ fake_datetime = datetime.strptime("2021-11-11 11:11:22.222000", "%Y-%m-%d %H:%M:%S.%f")
+ with freeze_time(fake_datetime):
+ self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_PRODUCTION)
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.current_stage, STAGE_PRODUCTION)
+
+ # check last updated time
+ self.assertEqual(m1d.last_updated_time, fake_datetime)
+ rmd = self.store.get_registered_model(name)
+ self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+ fake_datetime = datetime.strptime("2021-11-11 11:11:22.333000", "%Y-%m-%d %H:%M:%S.%f")
+ with freeze_time(fake_datetime):
+ self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_ARCHIVED)
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.current_stage, STAGE_ARCHIVED)
+
+ # check last updated time
+ self.assertEqual(m1d.last_updated_time, fake_datetime)
+ rmd = self.store.get_registered_model(name)
+ self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+ # non-canonical (differently cased) stage names are accepted
+ for uncanonical_stage_name in ["DEVELOPING", "developing", "DevElopIng"]:
+ self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_NONE)
+ self.store.transition_model_version_stage(mv1.name, mv1.version, uncanonical_stage_name)
+
+ m1d = self.store.get_model_version(mv1.name, mv1.version)
+ self.assertEqual(m1d.current_stage, STAGE_DEVELOPING)
+
+ # Not matching stages
+ with self.assertRaises(SubmarineException):
+ self.store.transition_model_version_stage(mv1.name, mv1.version, None)
+ # Not matching stages
+ with self.assertRaises(SubmarineException):
+ self.store.transition_model_version_stage(mv1.name, mv1.version, "stage")
+
+ # No change for other model
+ m2d = self.store.get_model_version(m2.name, m2.version)
+ self.assertEqual(m2d.current_stage, STAGE_NONE)
+
+ def test_delete_model_version(self):  # after deletion, every operation on the model version must raise
+ name = "test_for_delete_MV"
+ tags = ["tag1", "tag2"]
+ self.store.create_registered_model(name)
+ mv = self.store.create_model_version(
+ name, "path/to/source", "test", "application_1234", tags=tags
+ )
+ mvd = self.store.get_model_version(mv.name, mv.version)  # sanity check: the version exists before deletion
+ self.assertEqual(mvd.name, name)
+
+ self.store.delete_model_version(name=mv.name, version=mv.version)
+
+ # model version tags are expected to be cascade deleted with the model version, so deleting them again raises
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(mv.name, mv.version, tags[0])
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(mv.name, mv.version, tags[1])
+
+ # cannot get a deleted model version
+ with self.assertRaises(SubmarineException):
+ self.store.get_model_version(mv.name, mv.version)
+
+ # cannot update the description of a deleted model version
+ with self.assertRaises(SubmarineException):
+ self.store.update_model_version_description(mv.name, mv.version, "New description.")
+
+ # version=None does not identify an existing version, so deletion raises
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version(name=mv.name, version=None)
+
+ # name=None does not identify an existing registered model, so deletion raises
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version(name=None, version=mv.version)
+
+ @freeze_time("2021-11-11 11:11:11.111000")
+ def test_get_model_version(self):
+ name = "test_get_MV"
+ tags = ["tag1", "tag2"]
+ self.store.create_registered_model(name)
+ fake_datetime = datetime.now()
+ mv = self.store.create_model_version(
+ name,
+ source="path/to/source",
+ user_id="test",
+ experiment_id="application_1234",
+ tags=tags,
+ )
+ self.assertEqual(mv.creation_time, fake_datetime)
+ self.assertEqual(mv.last_updated_time, fake_datetime)
+ mvd = self.store.get_model_version(mv.name, mv.version)
+ self.assertEqual(mvd.name, name)
+ self.assertEqual(mvd.user_id, "test")
+ self.assertEqual(mvd.experiment_id, "application_1234")
+ self.assertEqual(mvd.current_stage, STAGE_NONE)
+ self.assertEqual(mvd.creation_time, fake_datetime)
+ self.assertEqual(mvd.last_updated_time, fake_datetime)
+ self.assertEqual(mvd.source, "path/to/source")
+ self.assertEqual(mvd.dataset, None)
+ self.assertEqual(mvd.description, None)
+ self.assertEqual(mvd.tags, tags)
+
+ def _compare_model_versions(self, results: List[ModelVersion], mms: List[ModelVersion]) -> None:
+ result_versions = set([result.version for result in results])
+ model_versions = set([mm.version for mm in mms])
+
+ self.assertEqual(result_versions, model_versions)
+
+ @freeze_time("2021-11-11 11:11:11.111000")
+ def test_list_model_versions(self):  # listing versions of a registered model, with and without filter_tags
+ name1 = "test_list_models_1"
+ name2 = "test_list_models_2"
+ self.store.create_registered_model(name1)
+ self.store.create_registered_model(name2)  # intentionally left without any model version
+ tags = ["tag1", "tag2", "tag3"]
+ models = [
+ self.store.create_model_version(name1, "path/to/source1", "test", "application_1234"),
+ self.store.create_model_version(
+ name1, "path/to/source2", "test", "application_1234", tags=[tags[0]]
+ ),
+ self.store.create_model_version(
+ name1, "path/to/source3", "test", "application_1234", tags=[tags[1]]
+ ),
+ self.store.create_model_version(
+ name1, "path/to/source4", "test", "application_1234", tags=[tags[0], tags[2]]
+ ),
+ self.store.create_model_version(
+ name1, "path/to/source5", "test", "application_1234", tags=tags
+ ),
+ ]
+
+ results = self.store.list_model_versions(name1)  # no filter: all five versions
+ self.assertEqual(len(results), 5)
+ self._compare_model_versions(results, models)
+
+ results = self.store.list_model_versions(name1, filter_tags=tags[0:1])  # versions carrying tag1
+ self.assertEqual(len(results), 3)
+ self._compare_model_versions(results, [models[1], models[3], models[4]])
+
+ results = self.store.list_model_versions(name1, filter_tags=tags[0:2])  # versions carrying both tag1 and tag2
+ self.assertEqual(len(results), 1)
+ self._compare_model_versions(results, [models[-1]])
+
+ # a tag that no version carries matches nothing
+ other_tag = ["tag4"]
+ results = self.store.list_model_versions(name1, filter_tags=other_tag)
+ self.assertEqual(len(results), 0)
+
+ # adding the unknown tag to the filter leaves no version that carries every requested tag
+ results = self.store.list_model_versions(name1, filter_tags=tags + other_tag)
+ self.assertEqual(len(results), 0)
+
+ # the second registered model has no versions, with or without a tag filter
+ results = self.store.list_model_versions(name2)
+ self.assertEqual(len(results), 0)
+ results = self.store.list_model_versions(name2, filter_tags=tags)
+ self.assertEqual(len(results), 0)
+
+ def test_get_model_version_uri(self):  # the URI equals the source given at creation and is stable across updates
+ name = "test_get_model_version_uri"
+ self.store.create_registered_model(name)
+ mv = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+ uri = self.store.get_model_version_uri(mv.name, mv.version)
+ self.assertEqual(uri, "path/to/source")
+
+ # the URI does not change when the stage or the description of the model version is updated
+ self.store.transition_model_version_stage(mv.name, mv.version, STAGE_PRODUCTION)
+ self.store.update_model_version_description(mv.name, mv.version, "New description.")
+ uri = self.store.get_model_version_uri(mv.name, mv.version)
+ self.assertEqual(uri, "path/to/source")
+
+ # cannot retrieve the URI of a deleted model version
+ self.store.delete_model_version(mv.name, mv.version)
+ with self.assertRaises(SubmarineException):
+ self.store.get_model_version_uri(mv.name, mv.version)
+
+ def test_add_model_version_tag(self):  # adding a tag affects only the targeted model version
+ name1 = "test_add_MV_tag"
+ name2 = "test_add_MV_tag_2"
+ tags = ["tag1", "tag2"]
+ self.store.create_registered_model(name1)
+ self.store.create_registered_model(name2)
+ rm1mv1 = self.store.create_model_version(
+ name1, "path/to/source1", "test", "application_1234", tags=tags
+ )
+ rm1m2 = self.store.create_model_version(
+ name1, "path/to/source2", "test", "application_1234", tags=tags
+ )
+ rm2mv1 = self.store.create_model_version(
+ name2, "path/to/source3", "test", "application_1234", tags=tags
+ )
+ new_tag = "new tag"
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
+ all_tags = [new_tag] + tags  # expected order of the tags as returned by the store
+ rm1m1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ self.assertEqual(rm1m1d.name, name1)
+ self.assertEqual(rm1m1d.tags, all_tags)
+
+ # adding a tag that already exists leaves the tag list unchanged
+ same_tag = "tag1"
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, same_tag)
+ mvd = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ self.assertEqual(mvd.tags, all_tags)
+
+ # other model versions (same model and a different model) keep their original tags
+ rm1m2d = self.store.get_model_version(rm1m2.name, rm1m2.version)
+ self.assertEqual(rm1m2d.name, name1)
+ self.assertEqual(rm1m2d.tags, tags)
+ rm2mv1 = self.store.get_model_version(rm2mv1.name, rm2mv1.version)
+ self.assertEqual(rm2mv1.name, name2)
+ self.assertEqual(rm2mv1.tags, tags)
+
+ # None and the empty string are rejected as tags
+ with self.assertRaises(SubmarineException):
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, None)
+ with self.assertRaises(SubmarineException):
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, "")
+
+ # cannot add a tag to a deleted (non-existent) model version
+ self.store.delete_model_version(rm1mv1.name, rm1mv1.version)
+ with self.assertRaises(SubmarineException):
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, same_tag)
+
+ def test_delete_model_tag(self):  # deleting a tag affects only the targeted model version
+ name1 = "test_delete_MV_tag"
+ name2 = "test_delete_MV_tag_2"
+ tags = ["tag1", "tag2"]
+ self.store.create_registered_model(name1)
+ self.store.create_registered_model(name2)
+ rm1mv1 = self.store.create_model_version(
+ name1, "path/to/source1", "test", "application_1234", tags=tags
+ )
+ rm1m2 = self.store.create_model_version(
+ name1, "path/to/source2", "test", "application_1234", tags=tags
+ )
+ rm2mv1 = self.store.create_model_version(
+ name2, "path/to/source3", "test", "application_1234", tags=tags
+ )
+ new_tag = "new tag"
+ self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)  # add then delete: back to the original tags
+ rm1m1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ self.assertEqual(rm1m1d.tags, tags)
+
+ # deleting a tag does not affect other model versions
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, tags[0])
+ rm1m1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ rm1m2d = self.store.get_model_version(rm1m2.name, rm1m2.version)
+ rm2mv1 = self.store.get_model_version(rm2mv1.name, rm2mv1.version)
+ self.assertEqual(rm1m1d.tags, tags[1:])
+ self.assertEqual(rm1m2d.tags, tags)
+ self.assertEqual(rm2mv1.tags, tags)
+
+ # deleting a tag that was already removed raises, and the remaining tags are untouched
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, tags[0])
+ rm1m1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+ self.assertEqual(rm1m1d.tags, tags[1:])
+
+ # None and the empty string are rejected as tags
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, None)
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1mv1.name, rm1mv1.version, "")
+
+ # cannot delete a tag of a deleted (non-existent) model version
+ self.store.delete_model_version(rm1m2.name, rm1m2.version)
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1m2.name, rm1m2.version, tags[0])
+
+ # None is rejected as a model name and as a version
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(None, rm1mv1.version, tags[1])
+ with self.assertRaises(SubmarineException):
+ self.store.delete_model_version_tag(rm1mv1.name, None, tags[1])
diff --git a/dev-support/style-check/python/mypy-requirements.txt b/submarine-sdk/pysubmarine/tests/store/tracking/__init__.py
similarity index 87%
copy from dev-support/style-check/python/mypy-requirements.txt
copy to submarine-sdk/pysubmarine/tests/store/tracking/__init__.py
index 6a581e3..a6eb1b5 100644
--- a/dev-support/style-check/python/mypy-requirements.txt
+++ b/submarine-sdk/pysubmarine/tests/store/tracking/__init__.py
@@ -12,9 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-mypy==0.910
-types-requests==2.25.6
-types-certifi==2020.4.0
-types-six==1.16.1
-types-python-dateutil==2.8.0
diff --git a/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
similarity index 90%
rename from submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
rename to submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
index d7597d8..dbdacdc 100644
--- a/submarine-sdk/pysubmarine/tests/store/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/tracking/test_sqlalchemy_store.py
@@ -22,7 +22,7 @@ import submarine
from submarine.entities import Metric, Param
from submarine.store.database import models
from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.tracking import utils
+from submarine.store.sqlalchemy_store import SqlAlchemyStore
JOB_ID = "application_123456789"
@@ -30,12 +30,12 @@ JOB_ID = "application_123456789"
@pytest.mark.e2e
class TestSqlAlchemyStore(unittest.TestCase):
def setUp(self):
- submarine.set_tracking_uri(
+ submarine.set_db_uri(
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
)
- self.tracking_uri = utils.get_tracking_uri()
- self.store = utils.get_sqlalchemy_store(self.tracking_uri)
- # TODO: use submarine.tracking.fluent to support experiment create
+ self.db_uri = submarine.get_db_uri()
+ self.store = SqlAlchemyStore(self.db_uri)
+ # TODO(KUAN-HSUN-LI): use submarine.tracking.fluent to support experiment create
with self.store.ManagedSessionMaker() as session:
instance = SqlExperiment(
id=JOB_ID,
@@ -49,7 +49,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
session.commit()
def tearDown(self):
- submarine.set_tracking_uri(None)
+ submarine.set_db_uri(None)
models.Base.metadata.drop_all(self.store.engine)
def test_log_param(self):
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
index e167a1c..7410e16 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
@@ -22,7 +22,7 @@ import pytest
import submarine
from submarine.store.database import models
from submarine.store.database.models import SqlExperiment, SqlMetric, SqlParam
-from submarine.tracking import utils
+from submarine.store.sqlalchemy_store import SqlAlchemyStore
JOB_ID = "application_123456789"
@@ -31,11 +31,11 @@ JOB_ID = "application_123456789"
class TestTracking(unittest.TestCase):
def setUp(self):
environ["JOB_ID"] = JOB_ID
- submarine.set_tracking_uri(
+ submarine.set_db_uri(
"mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
)
- self.tracking_uri = utils.get_tracking_uri()
- self.store = utils.get_sqlalchemy_store(self.tracking_uri)
+ self.db_uri = submarine.get_db_uri()
+ self.store = SqlAlchemyStore(self.db_uri)
# TODO: use submarine.tracking.fluent to support experiment create
with self.store.ManagedSessionMaker() as session:
instance = SqlExperiment(
@@ -50,7 +50,7 @@ class TestTracking(unittest.TestCase):
session.commit()
def tearDown(self):
- submarine.set_tracking_uri(None)
+ submarine.set_db_uri(None)
models.Base.metadata.drop_all(self.store.engine)
def test_log_param(self):
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
index d3e20e2..2fc3392 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_utils.py
@@ -24,27 +24,9 @@ from submarine.tracking.utils import (
_TRACKING_URI_ENV_VAR,
get_job_id,
get_sqlalchemy_store,
- get_tracking_uri,
- is_tracking_uri_set,
)
-def test_is_tracking_uri_set():
- env = {
- _TRACKING_URI_ENV_VAR: DEFAULT_SUBMARINE_JDBC_URL,
- }
- with mock.patch.dict(os.environ, env):
- assert is_tracking_uri_set() is True
-
-
-def test_get_tracking_uri():
- env = {
- _TRACKING_URI_ENV_VAR: DEFAULT_SUBMARINE_JDBC_URL,
- }
- with mock.patch.dict(os.environ, env):
- assert get_tracking_uri() == DEFAULT_SUBMARINE_JDBC_URL
-
-
def test_get_job_id():
env = {
_JOB_ID_ENV_VAR: "application_12346789",
diff --git a/submarine-sdk/pysubmarine/submarine/utils/__init__.py b/submarine-sdk/pysubmarine/tests/utils/test_db_utils.py
similarity index 50%
copy from submarine-sdk/pysubmarine/submarine/utils/__init__.py
copy to submarine-sdk/pysubmarine/tests/utils/test_db_utils.py
index 6f2b95c..2565082 100644
--- a/submarine-sdk/pysubmarine/submarine/utils/__init__.py
+++ b/submarine-sdk/pysubmarine/tests/utils/test_db_utils.py
@@ -13,25 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six.moves import urllib
+import os
-from submarine.exceptions import SubmarineException
+import mock
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import get_db_uri, set_db_uri
+from submarine.utils.db_utils import _DB_URI_ENV_VAR, is_db_uri_set
-def extract_db_type_from_uri(db_uri):
- """
- Parse the specified DB URI to extract the database type. Confirm the database type is
- supported. If a driver is specified, confirm it passes a plausible regex.
- """
- scheme = urllib.parse.urlparse(db_uri).scheme
- scheme_plus_count = scheme.count("+")
- if scheme_plus_count == 0:
- db_type = scheme
- elif scheme_plus_count == 1:
- db_type, _ = scheme.split("+")
- else:
- error_msg = "Invalid database URI: '%s'. %s" % (db_uri, "INVALID_DB_URI_MSG")
- raise SubmarineException(error_msg)
+def test_is_db_uri_set():
+ env = {
+ _DB_URI_ENV_VAR: DEFAULT_SUBMARINE_JDBC_URL,
+ }
+ with mock.patch.dict(os.environ, env):
+ assert is_db_uri_set() is True
- return db_type
+
+def test_set_db_uri():
+ test_db_uri = "mysql+pymysql://submarine:password@localhost:3306/submarine_test"
+ set_db_uri(test_db_uri)
+ assert get_db_uri() == test_db_uri
+ set_db_uri(None)
+
+
+def test_get_db_uri():
+ env = {
+ _DB_URI_ENV_VAR: DEFAULT_SUBMARINE_JDBC_URL,
+ }
+ with mock.patch.dict(os.environ, env):
+ assert get_db_uri() == DEFAULT_SUBMARINE_JDBC_URL
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org