Posted to commits@submarine.apache.org by GitBox <gi...@apache.org> on 2021/09/16 14:55:28 UTC

[GitHub] [submarine] KUAN-HSUN-LI opened a new pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

KUAN-HSUN-LI opened a new pull request #752:
URL: https://github.com/apache/submarine/pull/752


   ### What is this PR for?
   * Implement the model registry SQL methods in the Python SDK
   * Apply sqlalchemy mypy checks
   * Replace submarine tracking_uri with db_uri
   
   ### What type of PR is it?
   [Feature]
   
   ### Todos
   
   
   ### What is the Jira issue?
   https://issues.apache.org/jira/browse/SUBMARINE-1023
   
   ### How should this be tested?
   All of the tests are provided in `submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py`
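
As a dependency-free illustration of one behaviour the store code in this PR implements (commit when a block succeeds, roll back when it raises), the sketch below uses Python's standard `sqlite3` module rather than SQLAlchemy. The table name `registered_model` and the helpers `managed_connection` and `count_models` are assumptions made for this example; they are not taken from the PR's test file.

```python
import sqlite3
from contextlib import contextmanager


@contextmanager
def managed_connection(conn: sqlite3.Connection):
    """Commit if the with-block succeeds, roll back if it raises."""
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise


def count_models(conn: sqlite3.Connection) -> int:
    # Hypothetical table used only for this sketch.
    return conn.execute("SELECT COUNT(*) FROM registered_model").fetchone()[0]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE registered_model (name TEXT PRIMARY KEY)")

# Successful block: the insert is committed.
with managed_connection(conn) as c:
    c.execute("INSERT INTO registered_model VALUES ('model_a')")

# Failing block: the duplicate key raises, so 'model_b' is rolled back too.
try:
    with managed_connection(conn) as c:
        c.execute("INSERT INTO registered_model VALUES ('model_b')")
        c.execute("INSERT INTO registered_model VALUES ('model_a')")
except sqlite3.IntegrityError:
    pass
```

Unlike the session factory in the PR, this sketch reuses one connection across blocks instead of closing it when each block exits.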
   
   ### Screenshots (if appropriate)
   
   ### Questions:
   * Do the license files need updating? No
   * Are there breaking changes for older versions? No
   * Does this need new documentation? No
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@submarine.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r712987558



##########
File path: submarine-sdk/pysubmarine/submarine/utils/validation.py
##########
@@ -116,8 +117,48 @@ def validate_param(key, value):
     _validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
 
 
+def validate_tags(tags: Optional[List[str]]) -> None:
+    if tags is not None and not isinstance(tags, list):
+        raise SubmarineException("parameter tags must be list or None.")
+    for tag in tags or []:
+        validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+    """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+    # Reuse param & metric check.
+    if tag is None or tag == "":
+        raise SubmarineException("Tag cannot be empty.")
+    if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+        raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))
+
+
+def validate_model_name(model_name: str) -> None:
+    if model_name is None or model_name == "":
+        raise SubmarineException("Registered model name cannot be empty.")
+
+
+def validate_model_version(model_version: int) -> None:

Review comment:
       You forgot to change `model_version` to model metadata here.
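
A minimal sketch of how the tag validators in this hunk behave, runnable without the SDK. `SubmarineException` is redefined locally as a stand-in, `_VALID_TAG` is a hypothetical pattern (the real `_VALID_PARAM_AND_METRIC_NAMES` is not shown in the hunk), and `is_valid` is a helper added only for the example.

```python
import re
from typing import List, Optional


class SubmarineException(Exception):
    """Local stand-in for submarine.exceptions.SubmarineException (assumption)."""


# Hypothetical pattern; the SDK's actual regular expression is not shown in the hunk.
_VALID_TAG = re.compile(r"^[/\w.\- ]*$")


def validate_tag(tag: str) -> None:
    """Raise if `tag` is empty or contains characters outside the allowed set."""
    if tag is None or tag == "":
        raise SubmarineException("Tag cannot be empty.")
    if not _VALID_TAG.match(tag):
        raise SubmarineException(f"Invalid tag name: '{tag}'.")


def validate_tags(tags: Optional[List[str]]) -> None:
    """Accept None or a list of valid tags; reject any other container type."""
    if tags is not None and not isinstance(tags, list):
        raise SubmarineException("parameter tags must be list or None.")
    for tag in tags or []:
        validate_tag(tag)


def is_valid(tags) -> bool:
    # Convenience wrapper for the example: turn the exception into a boolean.
    try:
        validate_tags(tags)
        return True
    except SubmarineException:
        return False
```

Note that a bare string such as `"prod"` is rejected because it is not a list, even though it is iterable.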







[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r712989103



##########
File path: submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
##########
@@ -15,12 +15,11 @@
 
 from datetime import datetime

Review comment:
       This file name should be changed to test_model_metadata.py







[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711603580



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,570 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the ``mysql`` dialect.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Create the model registry tables if none of them exist yet.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(message=e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_discription(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={new_name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_str: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(
+                message=f"Registered model tag with name={name}, tag={tag} not found"
+            )
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,
+    ) -> ModelVersion:
+        """
+        Create a new model version
+        :param name: Registered model name.
+        :param user_id: User ID from server that created this model
+        :param experiment_id: ID of the experiment in which this model was created.
+        :param source: Source path where this model is stored.
+        :param dataset: Dataset used by this model.
+        :param description: Description of the version.
+        :param tags: A list of string associated with this model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+                 created in the backend.
+        """
+
+        def next_version(sql_registered_model: SqlRegisteredModel) -> int:
+            if sql_registered_model.model_versions:
+                return max([mv.version for mv in sql_registered_model.model_versions]) + 1
+            else:
+                return 1
+
+        validate_model_name(name)
+        validate_tags(tags)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                sql_registered_model = self._get_registered_model(session, name)
+                sql_registered_model.last_updated_time = creation_time
+                model_version = SqlModelVersion(
+                    name=name,
+                    version=next_version(sql_registered_model),
+                    source=source,
+                    user_id=user_id,
+                    experiment_id=experiment_id,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    dataset=dataset,
+                    description=description,
+                    model_tags=[SqlModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, [sql_registered_model, model_version])
+                session.flush()
+                return model_version.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError:
+                raise SubmarineException(message=f"Model Version creation error (name={name}).")
+
+    @classmethod
+    def _get_model_version(
+        cls, session: Session, name: str, version: int, eager: bool = False
+    ) -> SqlModelVersion:
+        """
+        :param eager: If ``True``, eagerly loads the model version's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlModelVersion`` object.
+        """
+        validate_model_name(name)
+        validate_model_version(version)
+        query_options = cls._get_eager_model_version_query_options() if eager else []
+        conditions = [
+            SqlModelVersion.name == name,
+            SqlModelVersion.version == version,
+            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
+        ]
+
+        versions: List[SqlModelVersion] = (
+            session.query(SqlModelVersion).options(*query_options).filter(*conditions).all()
+        )
+        if len(versions) == 0:
+            raise SubmarineException(f"Model Version (name={name}, version={version}) not found.")
+        elif len(versions) > 1:
+            raise SubmarineException(
+                f"Expected only 1 model version with (name={name}, version={version}). Found"
+                f" {len(versions)}."
+            )
+        else:
+            return versions[0]
+
+    def update_model_version_description(
+        self, name: str, version: int, description: str
+    ) -> ModelVersion:
+        """
+        Update description associated with a model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param description: New model description.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            update_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.description = description
+            sql_model_version.last_updated_time = update_time
+            self._save_to_db(session, sql_model_version)
+            return sql_model_version.to_submarine_entity()
+
+    def transition_model_version_stage(self, name: str, version: int, stage: str) -> ModelVersion:
+        """
+        Update model version stage.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param stage: New desired stage for this model version.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            last_updated_time = datetime.now()
+
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.current_stage = get_canonical_stage(stage)
+            sql_model_version.last_updated_time = last_updated_time
+            sql_registered_model = sql_model_version.registered_model
+            sql_registered_model.last_updated_time = last_updated_time
+            self._save_to_db(session, [sql_model_version, sql_registered_model])
+            return sql_model_version.to_submarine_entity()
+
+    def delete_model_version(self, name: str, version: int) -> None:
+        """
+        Delete model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            updated_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_registered_model = sql_model_version.registered_model
+            sql_registered_model.last_updated_time = updated_time
+            session.delete(sql_model_version)
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+
+    def get_model_version(self, name: str, version: int) -> ModelVersion:
+        """
+        Get the model version instance by name and version.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_model_version = self._get_model_version(session, name, version, True)
+            return sql_model_version.to_submarine_entity()
+
+    def list_model_version(self, name: str, filter_tags: List[str] = None) -> List[ModelVersion]:
+        """
+        List of all model versions that satisfy the filter criteria.
+        :param name: Registered model name.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.ModelVersion` objects
+                that satisfy the search expressions.
+        """
+        conditions = [SqlModelVersion.name == name]
+        if filter_tags is not None:
+            conditions = [
+                SqlModelVersion.model_tags.any(SqlModelTag.tag.contains(tag)) for tag in filter_tags

Review comment:
       Thanks for pointing out this problem.
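
As far as the truncated hunk shows, `conditions` in `list_model_version` is first set to the name filter and then reassigned when `filter_tags` is given, which would drop the name filter. Whether this is the problem referred to above is not shown in this message. A minimal sketch of accumulating conditions instead, using plain Python predicates in place of SQLAlchemy expressions (the `ROWS` data, `build_conditions`, and `list_versions` are assumptions for illustration):

```python
from typing import Any, Callable, Dict, List, Optional

Row = Dict[str, Any]
Condition = Callable[[Row], bool]

# Hypothetical in-memory rows standing in for model version records.
ROWS: List[Row] = [
    {"name": "a", "version": 1, "tags": ["prod"]},
    {"name": "b", "version": 1, "tags": ["prod"]},
    {"name": "a", "version": 2, "tags": ["dev"]},
]


def build_conditions(name: str, filter_tags: Optional[List[str]] = None) -> List[Condition]:
    # Start from the name filter, mirroring `conditions = [SqlModelVersion.name == name]`.
    conditions: List[Condition] = [lambda row: row["name"] == name]
    if filter_tags is not None:
        # Extend the list rather than reassigning it, so the name filter is kept.
        # The substring test mirrors the `.contains(tag)` match in the hunk.
        conditions += [
            (lambda row, tag=tag: any(tag in t for t in row["tags"])) for tag in filter_tags
        ]
    return conditions


def list_versions(
    rows: List[Row], name: str, filter_tags: Optional[List[str]] = None
) -> List[Row]:
    # A row is returned only if it satisfies every accumulated condition.
    conditions = build_conditions(name, filter_tags)
    return [row for row in rows if all(cond(row) for cond in conditions)]
```

With reassignment instead of `+=`, a call such as `list_versions(ROWS, "a", ["prod"])` would also return the matching version of model `b`.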







[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711603529



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,570 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the ``mysql`` dialect.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Create the model registry tables if none of them exist yet.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(message=e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_discription(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={new_name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_str: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(
+                message=f"Registered model tag with name={name}, tag={tag} not found"
+            )
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,
+    ) -> ModelVersion:
+        """
+        Create a new model version
+        :param name: Registered model name.
+        :param user_id: ID of the user (from the server) who created this model.
+        :param experiment_id: ID of the experiment in which this model was created.
+        :param source: Source path where this model is stored.
+        :param dataset: Dataset used by this model.
+        :param description: Description of the version.
+        :param tags: A list of strings associated with this model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+                 created in the backend.
+        """
+
+        def next_version(sql_registered_model: SqlRegisteredModel) -> int:
+            if sql_registered_model.model_versions:
+                return max([mv.version for mv in sql_registered_model.model_versions]) + 1
+            else:
+                return 1
+
+        validate_model_name(name)
+        validate_tags(tags)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                sql_registered_model = self._get_registered_model(session, name)
+                sql_registered_model.last_updated_time = creation_time
+                model_version = SqlModelVersion(
+                    name=name,
+                    version=next_version(sql_registered_model),
+                    source=source,
+                    user_id=user_id,
+                    experiment_id=experiment_id,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    dataset=dataset,
+                    description=description,
+                    model_tags=[SqlModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, [sql_registered_model, model_version])
+                session.flush()
+                return model_version.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError:
+                raise SubmarineException(message=f"Model Version creation error (name={name}).")
+
+    @classmethod
+    def _get_model_version(
+        cls, session: Session, name: str, version: int, eager: bool = False
+    ) -> SqlModelVersion:
+        """
+        :param eager: If ``True``, eagerly loads the model version's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlModelVersion`` object.
+        """
+        validate_model_name(name)
+        validate_model_version(version)
+        query_options = cls._get_eager_model_version_query_options() if eager else []
+        conditions = [
+            SqlModelVersion.name == name,
+            SqlModelVersion.version == version,
+            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
+        ]
+
+        versions: List[SqlModelVersion] = (
+            session.query(SqlModelVersion).options(*query_options).filter(*conditions).all()
+        )
+        if len(versions) == 0:
+            raise SubmarineException(f"Model Version (name={name}, version={version}) not found.")
+        elif len(versions) > 1:
+            raise SubmarineException(
+                f"Expected only 1 model version with (name={name}, version={version}). Found"
+                f" {len(versions)}."
+            )
+        else:
+            return versions[0]
+
+    def update_model_version_description(
+        self, name: str, version: int, description: str
+    ) -> ModelVersion:
+        """
+        Update description associated with a model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param description: New model description.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            update_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.description = description
+            sql_model_version.last_updated_time = update_time
+            self._save_to_db(session, sql_model_version)

Review comment:
       I will add the check
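
For readers following the quoted hunk: the commit/rollback/close ordering of `_get_managed_session_maker` can be tried without a database. This is a minimal sketch, not the PR's code. `StoreError` and `RecordingSession` are hypothetical stand-ins for `SubmarineException` and a SQLAlchemy `Session`, and `run` and `fail` are helper names made up for the illustration.

```python
from contextlib import contextmanager


class StoreError(Exception):
    """Hypothetical stand-in for SubmarineException."""


class RecordingSession:
    """Hypothetical stand-in for a SQLAlchemy Session; records lifecycle calls."""

    def __init__(self):
        self.calls = []

    def commit(self):
        self.calls.append("commit")

    def rollback(self):
        self.calls.append("rollback")

    def close(self):
        self.calls.append("close")


def get_managed_session_maker(session_factory):
    @contextmanager
    def make_managed_session():
        session = session_factory()
        try:
            yield session
            session.commit()
        except StoreError:
            # Store errors pass through unchanged after the rollback.
            session.rollback()
            raise
        except Exception as e:
            session.rollback()
            # Wrap unexpected errors, keeping the original as the cause.
            raise StoreError(str(e)) from e
        finally:
            session.close()

    return make_managed_session


def run(action):
    """Run `action` inside a managed session and return the recorded calls."""
    sessions = []

    def factory():
        sessions.append(RecordingSession())
        return sessions[-1]

    managed = get_managed_session_maker(factory)
    try:
        with managed() as session:
            action(session)
    except StoreError:
        pass
    return sessions[0].calls


def fail(session):
    raise ValueError("boom")
```

Calling `run(lambda s: None)` exercises the success path and `run(fail)` the failure path. The sketch wraps with `str(e)` and `from e`, which differs slightly from the `message=e` form in the hunk.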







[GitHub] [submarine] asfgit closed pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
asfgit closed pull request #752:
URL: https://github.com/apache/submarine/pull/752


   





[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r713670191



##########
File path: submarine-sdk/pysubmarine/submarine/utils/validation.py
##########
@@ -116,8 +117,48 @@ def validate_param(key, value):
     _validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
 
 
+def validate_tags(tags: Optional[List[str]]) -> None:
+    if tags is not None and not isinstance(tags, list):
+        raise SubmarineException("parameter tags must be list or None.")
+    for tag in tags or []:
+        validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+    """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+    # Reuse param & metric check.
+    if tag is None or tag == "":
+        raise SubmarineException("Tag cannot be empty.")
+    if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+        raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))
+
+
+def validate_model_name(model_name: str) -> None:
+    if model_name is None or model_name == "":
+        raise SubmarineException("Registered model name cannot be empty.")
+
+
+def validate_model_version(model_version: int) -> None:

Review comment:
       ok







[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r713236288



##########
File path: submarine-sdk/pysubmarine/submarine/utils/validation.py
##########
@@ -116,8 +117,48 @@ def validate_param(key, value):
     _validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
 
 
+def validate_tags(tags: Optional[List[str]]) -> None:
+    if tags is not None and not isinstance(tags, list):
+        raise SubmarineException("parameter tags must be list or None.")
+    for tag in tags or []:
+        validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+    """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+    # Reuse param & metric check.
+    if tag is None or tag == "":
+        raise SubmarineException("Tag cannot be empty.")
+    if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+        raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))
+
+
+def validate_model_name(model_name: str) -> None:
+    if model_name is None or model_name == "":
+        raise SubmarineException("Registered model name cannot be empty.")
+
+
+def validate_model_version(model_version: int) -> None:

Review comment:
       I think this function name is fine.
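
As a rough illustration of how the validators in this hunk behave, here is a standalone sketch. Several parts are assumptions: the `_VALID_NAMES` pattern is a guess (the real `_VALID_PARAM_AND_METRIC_NAMES` is not shown above), `ValidationError` stands in for `SubmarineException`, and the body of `validate_model_version` is invented because the quoted diff cuts off at its signature. `is_valid` is a helper added only for the example.

```python
import re
from typing import List, Optional

# Assumption: the real _VALID_PARAM_AND_METRIC_NAMES pattern is not shown in the
# quoted hunk; this pattern is only an illustrative guess.
_VALID_NAMES = re.compile(r"^[/\w.\- ]*$")


class ValidationError(Exception):
    """Hypothetical stand-in for SubmarineException."""


def validate_tag(tag: str) -> None:
    if tag is None or tag == "":
        raise ValidationError("Tag cannot be empty.")
    if not _VALID_NAMES.match(tag):
        raise ValidationError(f"Invalid tag name: '{tag}'.")


def validate_tags(tags: Optional[List[str]]) -> None:
    if tags is not None and not isinstance(tags, list):
        raise ValidationError("parameter tags must be list or None.")
    for tag in tags or []:
        validate_tag(tag)


def validate_model_version(model_version: int) -> None:
    # Assumption: the body of this function is cut off in the quoted hunk.
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(model_version, bool) or not isinstance(model_version, int):
        raise ValidationError("Model version must be an integer.")
    if model_version <= 0:
        raise ValidationError("Model version must be greater than 0.")


def is_valid(check, value) -> bool:
    """Return True if `check(value)` passes, False if it raises ValidationError."""
    try:
        check(value)
        return True
    except ValidationError:
        return False
```

With these definitions, `None` and a list of well-formed strings pass `validate_tags`, while a bare string or a list containing an empty tag is rejected.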







[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711659973



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
##########
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+
+
+class AbstractStore:
+    """
+    Abstract class for Backend model registry
+    This class defines the API interface for frontends to connect with various types of backends.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self) -> None:
+        """
+        Empty constructor for now. This is deliberately not marked as abstract, else every
+        derived class would be forced to create one.
+        """
+        pass
+
+    @abstractmethod
+    def create_registered_model(

Review comment:
       The model registry comprises four tables: `registered_model`, `registered_model_tag`, `model_version`, and `model_version_tag`.
   Tracking comprises three tables: `param`, `metric`, and `experiment`.
   In my opinion, it is clearer to put these functions in different directories. In this PR, I only worked on the model registry directory; maybe I should also refactor the tracking directory and implement the `experiment` table.







[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711659506



##########
File path: submarine-sdk/pysubmarine/github-actions/test-requirements.txt
##########
@@ -23,11 +23,12 @@ pytest==3.2.1
 pytest-cov==2.6.0
 pytest-localserver==0.5.0
 pylint==2.5.2
-sqlalchemy==1.3.0
+sqlalchemy >= 1.4.0

Review comment:
       Sure, I will remove it.







[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711577615



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,570 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the ``mysql`` dialect.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(message=e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_discription(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_string: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(
+                message=f"Registered model tag with name={name}, tag={tag} not found"
+            )
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,
+    ) -> ModelVersion:
+        """
+        Create a new model version
+        :param name: Registered model name.
+        :param user_id: User ID from server that created this model
+        :param experiment_id: Experiment ID which this model is created.
+        :param source: Source path where this model is stored.
+        :param dataset: Dataset which this model is used.
+        :param description: Description of the version.
+        :param tags: A list of string associated with this model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+                 created in the backend.
+        """
+
+        def next_version(sql_registered_model: SqlRegisteredModel) -> int:
+            if sql_registered_model.model_versions:
+                return max([mv.version for mv in sql_registered_model.model_versions]) + 1
+            else:
+                return 1
+
+        validate_model_name(name)
+        validate_tags(tags)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                sql_registered_model = self._get_registered_model(session, name)
+                sql_registered_model.last_updated_time = creation_time
+                model_version = SqlModelVersion(
+                    name=name,
+                    version=next_version(sql_registered_model),
+                    source=source,
+                    user_id=user_id,
+                    experiment_id=experiment_id,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    dataset=dataset,
+                    description=description,
+                    model_tags=[SqlModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, [sql_registered_model, model_version])
+                session.flush()
+                return model_version.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError:
+                raise SubmarineException(message=f"Model Version creation error (name={name}).")
+
+    @classmethod
+    def _get_model_version(
+        cls, session: Session, name: str, version: int, eager: bool = False
+    ) -> SqlModelVersion:
+        """
+        :param eager: If ``True``, eagerly loads the model version's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlModelVersion`` object.
+        """
+        validate_model_name(name)
+        validate_model_version(version)
+        query_options = cls._get_eager_model_version_query_options() if eager else []
+        conditions = [
+            SqlModelVersion.name == name,
+            SqlModelVersion.version == version,
+            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
+        ]
+
+        versions: List[SqlModelVersion] = (
+            session.query(SqlModelVersion).options(*query_options).filter(*conditions).all()
+        )
+        if len(versions) == 0:
+            raise SubmarineException(f"Model Version (name={name}, version={version}) not found.")
+        elif len(versions) > 1:
+            raise SubmarineException(
+                f"Expected only 1 model version with (name={name}, version={versions}). Found"
+                f" {len(versions)}."
+            )
+        else:
+            return versions[0]
+
+    def update_model_version_description(
+        self, name: str, version: int, description: str
+    ) -> ModelVersion:
+        """
+        Update description associated with a model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param description: New model description.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            update_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.description = description
+            sql_model_version.last_updated_time = update_time
+            self._save_to_db(session, sql_model_version)

Review comment:
       Is it necessary to check the length of the description?
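
One way to address this question would be a length check that runs before the description is written to the database. The following is a minimal sketch, not code from this PR: `MAX_DESCRIPTION_LENGTH`, `validate_description`, and the stand-in `SubmarineException` class are assumptions made for illustration, and the real limit would have to match the size of the description column in the database schema.

```python
from typing import Optional

# Hypothetical limit; the real value should match the database column size.
MAX_DESCRIPTION_LENGTH = 5000


class SubmarineException(Exception):
    """Stand-in for submarine.exceptions.SubmarineException."""


def validate_description(description: Optional[str]) -> None:
    """Raise if the description is not a string or exceeds the length limit."""
    if description is None:
        return  # a missing description is allowed
    if not isinstance(description, str):
        raise SubmarineException(
            f"Description must be a string, got {type(description).__name__}"
        )
    if len(description) > MAX_DESCRIPTION_LENGTH:
        raise SubmarineException(
            f"Description exceeds the limit of {MAX_DESCRIPTION_LENGTH} characters "
            f"(got {len(description)})"
        )
```

Such a helper could be called at the top of the description update methods, next to the existing name validation, so that an oversized value fails before a session is opened.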







[GitHub] [submarine] KUAN-HSUN-LI commented on pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on pull request #752:
URL: https://github.com/apache/submarine/pull/752#issuecomment-922758424


   @pingsutw @jeff-901 Thanks for your review. I have replaced `registered model` with `model container` and replaced `model version` with `model metadata`. Additionally, I have fixed several bugs.





[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711578893



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,570 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the dialects ``mysql``.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(message=e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_discription(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    message=f"Registered Model (name={name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_string: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(
+                message=f"Registered model tag with name={name}, tag={tag} not found"
+            )
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,
+    ) -> ModelVersion:
+        """
+        Create a new model version
+        :param name: Registered model name.
+        :param user_id: User ID from server that created this model
+        :param experiment_id: Experiment ID which this model is created.
+        :param source: Source path where this model is stored.
+        :param dataset: Dataset which this model is used.
+        :param description: Description of the version.
+        :param tags: A list of string associated with this model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.ModelVersion`
+                 created in the backend.
+        """
+
+        def next_version(sql_registered_model: SqlRegisteredModel) -> int:
+            if sql_registered_model.model_versions:
+                return max([mv.version for mv in sql_registered_model.model_versions]) + 1
+            else:
+                return 1
+
+        validate_model_name(name)
+        validate_tags(tags)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                sql_registered_model = self._get_registered_model(session, name)
+                sql_registered_model.last_updated_time = creation_time
+                model_version = SqlModelVersion(
+                    name=name,
+                    version=next_version(sql_registered_model),
+                    source=source,
+                    user_id=user_id,
+                    experiment_id=experiment_id,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    dataset=dataset,
+                    description=description,
+                    model_tags=[SqlModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, [sql_registered_model, model_version])
+                session.flush()
+                return model_version.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError:
+                raise SubmarineException(message=f"Model Version creation error (name={name}).")
+
+    @classmethod
+    def _get_model_version(
+        cls, session: Session, name: str, version: int, eager: bool = False
+    ) -> SqlModelVersion:
+        """
+        :param eager: If ``True``, eagerly loads the model version's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlModelVersion`` object.
+        """
+        validate_model_name(name)
+        validate_model_version(version)
+        query_options = cls._get_eager_model_version_query_options() if eager else []
+        conditions = [
+            SqlModelVersion.name == name,
+            SqlModelVersion.version == version,
+            SqlModelVersion.current_stage != STAGE_DELETED_INTERNAL,
+        ]
+
+        versions: List[SqlModelVersion] = (
+            session.query(SqlModelVersion).options(*query_options).filter(*conditions).all()
+        )
+        if len(versions) == 0:
+            raise SubmarineException(f"Model Version (name={name}, version={version}) not found.")
+        elif len(versions) > 1:
+            raise SubmarineException(
+                f"Expected only 1 model version with (name={name}, version={versions}). Found"
+                f" {len(versions)}."
+            )
+        else:
+            return versions[0]
+
+    def update_model_version_description(
+        self, name: str, version: int, description: str
+    ) -> ModelVersion:
+        """
+        Update description associated with a model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param description: New model description.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            update_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.description = description
+            sql_model_version.last_updated_time = update_time
+            self._save_to_db(session, sql_model_version)
+            return sql_model_version.to_submarine_entity()
+
+    def transition_model_version_stage(self, name: str, version: int, stage: str) -> ModelVersion:
+        """
+        Update model version stage.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :param stage: New desired stage for this model version.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            last_updated_time = datetime.now()
+
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_model_version.current_stage = get_canonical_stage(stage)
+            sql_model_version.last_updated_time = last_updated_time
+            sql_registered_model = sql_model_version.registered_model
+            sql_registered_model.last_updated_time = last_updated_time
+            self._save_to_db(session, [sql_model_version, sql_registered_model])
+            return sql_model_version.to_submarine_entity()
+
+    def delete_model_version(self, name: str, version: int) -> None:
+        """
+        Delete model version in backend.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            updated_time = datetime.now()
+            sql_model_version = self._get_model_version(session, name, version)
+            sql_registered_model = sql_model_version.registered_model
+            sql_registered_model.last_updated_time = updated_time
+            session.delete(sql_model_version)
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+
+    def get_model_version(self, name: str, version: int) -> ModelVersion:
+        """
+        Get the model version instance by name and version.
+        :param name: Registered model name.
+        :param version: Registered model version.
+        :return: A single :py:class:`submarine.entities.model_registry.ModelVersion` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_model_version = self._get_model_version(session, name, version, True)
+            return sql_model_version.to_submarine_entity()
+
+    def list_model_version(self, name: str, filter_tags: List[str] = None) -> List[ModelVersion]:
+        """
+        List of all model versions that satisfy the filter criteria.
+        :param name: Registered model name.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.ModelVersion` objects
+                that satisfy the search expressions.
+        """
+        conditions = [SqlModelVersion.name == name]
+        if filter_tags is not None:
+            conditions = [
+                SqlModelVersion.model_tags.any(SqlModelTag.tag.contains(tag)) for tag in filter_tags

Review comment:
       Is it a mistake to reassign `conditions` when `filter_tags` is not None? That would discard the `SqlModelVersion.name == name` filter set just above.
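
To make the concern concrete, here is a minimal sketch. It uses plain strings as hypothetical stand-ins for the SQLAlchemy filter expressions, and it assumes nothing after the quoted lines re-adds the name filter (the hunk is cut off, so that is not confirmed). The function names are illustrative only.

```python
from typing import List, Optional


def build_conditions_reassign(name: str, filter_tags: Optional[List[str]] = None) -> List[str]:
    # Same shape as the quoted diff: the list is replaced, so the name filter is lost.
    conditions = [f"name == {name!r}"]
    if filter_tags is not None:
        conditions = [f"tag contains {tag!r}" for tag in filter_tags]
    return conditions


def build_conditions_extend(name: str, filter_tags: Optional[List[str]] = None) -> List[str]:
    # Possible fix: extend the existing list so the name filter is kept.
    conditions = [f"name == {name!r}"]
    if filter_tags is not None:
        conditions += [f"tag contains {tag!r}" for tag in filter_tags]
    return conditions
```

With the reassigning variant, passing any tag list would return versions of every registered model that carries those tags, not only versions of `name`.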







[GitHub] [submarine] pingsutw commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
pingsutw commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711610888



##########
File path: submarine-sdk/pysubmarine/submarine/utils/db_utils.py
##########
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import env
+
+_DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
+
+
+_db_uri = None
+
+
+def is_db_uri_set():
+    """Returns True if the DB URI has been set, False otherwise."""
+    if _db_uri or env.get_env(_DB_URI_ENV_VAR):
+        return True
+    return False
+
+
+def set_db_uri(uri):
+    """
+    Set the DB URI. This does not affect the currently active run (if one exists),
+    but takes effect for successive runs.
+    """
+    global _tracking_uri
+    _tracking_uri = uri
+
+
+def get_db_uri():
+    """
+    Get the current DB URI.
+    :return: The tracking URI.

Review comment:
       ```suggestion
       :return: The DB URI.
   ```

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_description,
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the dialects ``mysql``.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_description(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_string: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(f"Registered model tag with name={name}, tag={tag} not found")
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,

Review comment:
       What's the difference between the tags in `model version` and `registered model`?

##########
File path: submarine-sdk/pysubmarine/submarine/utils/db_utils.py
##########
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import env
+
+_DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
+
+
+_db_uri = None
+
+
+def is_db_uri_set():
+    """Returns True if the DB URI has been set, False otherwise."""
+    if _db_uri or env.get_env(_DB_URI_ENV_VAR):
+        return True
+    return False
+
+
+def set_db_uri(uri):
+    """
+    Set the DB URI. This does not affect the currently active run (if one exists),
+    but takes effect for successive runs.
+    """
+    global _tracking_uri
+    _tracking_uri = uri

Review comment:
       Should this be `_db_uri`? The module-level variable declared above is `_db_uri`, not `_tracking_uri`.
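
A minimal sketch of why the name matters, reduced to the module-level state only. The function names `set_db_uri_as_in_diff` and `set_db_uri_fixed` are hypothetical, the URI is a made-up example, and `is_db_uri_set` is simplified here (the version in the diff also checks the `SUBMARINE_DB_URI` environment variable).

```python
_db_uri = None


def is_db_uri_set() -> bool:
    # Simplified: only looks at the module-level variable.
    return bool(_db_uri)


def set_db_uri_as_in_diff(uri: str) -> None:
    # Declares and assigns a different global, so _db_uri stays None.
    global _tracking_uri
    _tracking_uri = uri


def set_db_uri_fixed(uri: str) -> None:
    # Assigns the variable that the getter and is_db_uri_set() actually read.
    global _db_uri
    _db_uri = uri
```

Calling `set_db_uri_as_in_diff("mysql://example-host/db")` leaves `is_db_uri_set()` returning False; only the fixed variant makes it return True.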

##########
File path: submarine-sdk/pysubmarine/github-actions/test-requirements.txt
##########
@@ -23,11 +23,12 @@ pytest==3.2.1
 pytest-cov==2.6.0
 pytest-localserver==0.5.0
 pylint==2.5.2
-sqlalchemy==1.3.0
+sqlalchemy >= 1.4.0

Review comment:
       We could remove it, right? We already have it in setup.py.

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
##########
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+
+
+class AbstractStore:
+    """
+    Abstract class for Backend model registry
+    This class defines the API interface for frontends to connect with various types of backends.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self) -> None:
+        """
+        Empty constructor for now. This is deliberately not marked as abstract, else every
+        derived class would be forced to create one.
+        """
+        pass
+
+    @abstractmethod
+    def create_registered_model(

Review comment:
       Do we need to put these functions here? https://github.com/apache/submarine/blob/master/submarine-sdk/pysubmarine/submarine/store/abstract_store.py

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_description,
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the dialects ``mysql``.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(

Review comment:
       Ditto.
   https://github.com/apache/submarine/blob/master/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py

##########
File path: submarine-sdk/pysubmarine/submarine/utils/validation.py
##########
@@ -116,6 +117,45 @@ def validate_param(key, value):
     _validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
 
 
+def validate_tags(tags: Optional[List[str]]) -> None:
+    if tags is not None and not isinstance(tags, list):
+        raise SubmarineException("parameter tags must be list or None.")
+    for tag in tags or []:
+        validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+    """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+    # Reuse param & metric check.
+    if tag is None or tag == "":
+        raise SubmarineException("Tag cannot be empty.")
+    if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+        raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))

Review comment:
       ```suggestion
           raise SubmarineException(f"Invalid tag name: {tag}. {_BAD_CHARACTERS_MESSAGE}")
   ```
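
A small sketch comparing the two message formats. `_BAD_CHARACTERS_MESSAGE` below is a placeholder string, not the SDK's actual text, and the function names are illustrative. It shows that the f-string as suggested drops the single quotes around the tag, and that keeping them only requires writing them inside the f-string.

```python
# Placeholder; the real constant is defined in submarine/utils/validation.py.
_BAD_CHARACTERS_MESSAGE = "Names may only contain permitted characters."


def message_percent(tag: str) -> str:
    # Format used in the diff.
    return "Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE)


def message_fstring_suggested(tag: str) -> str:
    # Format from the review suggestion: no quotes around the tag.
    return f"Invalid tag name: {tag}. {_BAD_CHARACTERS_MESSAGE}"


def message_fstring_quoted(tag: str) -> str:
    # f-string that reproduces the original message exactly.
    return f"Invalid tag name: '{tag}'. {_BAD_CHARACTERS_MESSAGE}"
```

If a test asserts on the exact error text, the quoted variant keeps it passing without changes.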







[GitHub] [submarine] KUAN-HSUN-LI commented on pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on pull request #752:
URL: https://github.com/apache/submarine/pull/752#issuecomment-921415492


   @pingsutw @jeff-901 Could you help me review the code? Thanks.





[GitHub] [submarine] KUAN-HSUN-LI commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711660803



##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_description,
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the ``mysql`` dialect.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of strings associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_description(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={new_name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_str: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(f"Registered model tag with name={name}, tag={tag} not found")
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,

Review comment:
       A `registered model tag` describes the goal or category of the model as a whole, while a `model version tag` describes what a specific version does or which features it has. The two are easy to confuse, so I hope this explanation helps.
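
A minimal sketch of the distinction described in this comment, using plain dataclasses instead of the SDK entities. The class names `RegisteredModelSketch` and `ModelVersionSketch`, the model name `ctr_model`, and all tag values are hypothetical and chosen only for illustration; they are not part of the Submarine API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ModelVersionSketch:
    """Hypothetical stand-in for a model version (not the SDK's ModelVersion)."""

    name: str
    version: int
    # Version-level tags: what this particular version does or features.
    tags: List[str] = field(default_factory=list)


@dataclass
class RegisteredModelSketch:
    """Hypothetical stand-in for a registered model (not the SDK's RegisteredModel)."""

    name: str
    # Model-level tags: the goal or category shared by every version.
    tags: List[str] = field(default_factory=list)
    versions: Dict[int, ModelVersionSketch] = field(default_factory=dict)

    def add_version(self, tags: List[str]) -> ModelVersionSketch:
        # Versions are numbered 1, 2, 3, ... under the same registered name.
        version = len(self.versions) + 1
        model_version = ModelVersionSketch(self.name, version, list(tags))
        self.versions[version] = model_version
        return model_version


model = RegisteredModelSketch("ctr_model", tags=["recommendation", "ranking"])
v1 = model.add_version(tags=["baseline"])
v2 = model.add_version(tags=["larger-embedding", "new-dataset"])
```

In this sketch the model-level tags stay the same across versions, while each version carries its own tags.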







[GitHub] [submarine] KUAN-HSUN-LI commented on pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
KUAN-HSUN-LI commented on pull request #752:
URL: https://github.com/apache/submarine/pull/752#issuecomment-927156420


   @pingsutw @jeff-901 I am sorry for changing the table names in the database again. I think the previous names are the best, and I explain why below.
   Workflow of the model registry:
   1. Save models in an S3 bucket (unregistered models).
   2. Choose the best model and register it (registered model, version 1).
   3. Train the model with different parameters or datasets and save the resulting models in the S3 bucket (unregistered models).
   4. Again, choose the best model and register it (registered model, version 2).
   
   Summary:
   * The table stores registered models, so the name `registered_model` is more appropriate.
   * Registering another model under the same name creates a new version, which is stored in a separate table. It is therefore appropriate to name that table `model_version`.
   
   We also need two tag tables.
   Example:
   ![image](https://user-images.githubusercontent.com/38066413/134780767-b80258fa-a901-4dd8-a693-0f9d69c24dfa.png)
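
The workflow and table layout described above can be sketched with the standard-library `sqlite3` module. This is only an illustration under stated assumptions: the table names `registered_model` and `model_version` come from the comment, while the tag table names, all column names, the helper `register_version`, and the `s3://` paths are hypothetical and do not reflect the actual Submarine schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript(
    """
    CREATE TABLE registered_model (
        name TEXT PRIMARY KEY
    );
    CREATE TABLE model_version (
        name    TEXT NOT NULL REFERENCES registered_model(name) ON DELETE CASCADE,
        version INTEGER NOT NULL,
        source  TEXT NOT NULL,
        PRIMARY KEY (name, version)
    );
    -- Assumed tag tables: one per level (model and version).
    CREATE TABLE registered_model_tag (
        name TEXT NOT NULL REFERENCES registered_model(name) ON DELETE CASCADE,
        tag  TEXT NOT NULL,
        PRIMARY KEY (name, tag)
    );
    CREATE TABLE model_version_tag (
        name    TEXT NOT NULL,
        version INTEGER NOT NULL,
        tag     TEXT NOT NULL,
        PRIMARY KEY (name, version, tag),
        FOREIGN KEY (name, version)
            REFERENCES model_version(name, version) ON DELETE CASCADE
    );
    """
)


def register_version(conn, name, source, tags=()):
    """Register the chosen model as the next version under `name` (hypothetical helper)."""
    # Create the registered model on first use; later calls reuse the same row.
    conn.execute("INSERT OR IGNORE INTO registered_model (name) VALUES (?)", (name,))
    (latest,) = conn.execute(
        "SELECT COALESCE(MAX(version), 0) FROM model_version WHERE name = ?", (name,)
    ).fetchone()
    version = latest + 1
    conn.execute(
        "INSERT INTO model_version (name, version, source) VALUES (?, ?, ?)",
        (name, version, source),
    )
    conn.executemany(
        "INSERT INTO model_version_tag (name, version, tag) VALUES (?, ?, ?)",
        [(name, version, tag) for tag in tags],
    )
    return version


# Steps 2 and 4 of the workflow: the best candidate of each round is registered.
v1 = register_version(conn, "ctr_model", "s3://bucket/run-3", tags=["baseline"])
v2 = register_version(conn, "ctr_model", "s3://bucket/run-7", tags=["new-dataset"])
conn.execute(
    "INSERT INTO registered_model_tag (name, tag) VALUES (?, ?)",
    ("ctr_model", "recommendation"),
)
conn.commit()
```

Registering twice under one name yields a single `registered_model` row and two `model_version` rows, which is the reason the comment gives for keeping the two tables separate.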
   





[GitHub] [submarine] jeff-901 commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

Posted by GitBox <gi...@apache.org>.
jeff-901 commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711579353



##########
File path: submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
##########
@@ -0,0 +1,731 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from typing import List
+
+import freezegun
+import pytest
+from freezegun import freeze_time
+
+import submarine
+from submarine.entities.model_registry.model_version import ModelVersion
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_ARCHIVED,
+    STAGE_NONE,
+    STAGE_PRODUCTION,
+    STAGE_STAGING,
+)
+from submarine.entities.model_registry.registered_model import RegisteredModel
+from submarine.exceptions import SubmarineException
+from submarine.store.database import models
+from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
+
+freezegun.configure(default_ignore_list=["threading", "tensorflow"])
+
+
+@pytest.mark.e2e
+class TestSqlAlchemyStore(unittest.TestCase):
+    def setUp(self):
+        submarine.set_db_uri(
+            "mysql+pymysql://submarine_test:password_test@localhost:3306/submarine_test"
+        )
+        self.db_uri = submarine.get_db_uri()
+        self.store = SqlAlchemyStore(self.db_uri)
+
+    def tearDown(self):
+        submarine.set_db_uri(None)
+        models.Base.metadata.drop_all(self.store.engine)
+
+    def test_create_registered_model(self):
+        name1 = "test_create_RM_1"
+        rm1 = self.store.create_registered_model(name1)
+        self.assertEqual(rm1.name, name1)
+        self.assertEqual(rm1.description, None)
+
+        # error in duplicate
+        with self.assertRaises(SubmarineException):
+            self.store.create_registered_model(name1)
+
+        # test create with tags
+        name2 = "test_create_RM_2"
+        tags = ["tag1", "tag2"]
+        rm2 = self.store.create_registered_model(name2, tags=tags)
+        rm2d = self.store.get_registered_model(name2)
+        self.assertEqual(rm2.name, name2)
+        self.assertEqual(rm2.tags, tags)
+        self.assertEqual(rm2d.name, name2)
+        self.assertEqual(rm2d.tags, tags)
+
+        # test create with description
+        name3 = "test_create_RM_3"
+        description = "A test description."
+        rm3 = self.store.create_registered_model(name3, description)
+        rmd3 = self.store.get_registered_model(name3)
+        self.assertEqual(rm3.name, name3)
+        self.assertEqual(rm3.description, description)
+        self.assertEqual(rmd3.name, name3)
+        self.assertEqual(rmd3.description, description)
+
+        # invalid model name
+        with self.assertRaises(SubmarineException):
+            self.store.create_registered_model(None)
+        with self.assertRaises(SubmarineException):
+            self.store.create_registered_model("")
+
+    def test_update_registered_model_description(self):
+        name = "test_update_RM"
+        rm1 = self.store.create_registered_model(name)
+        rmd1 = self.store.get_registered_model(name)
+        self.assertEqual(rm1.name, name)
+        self.assertEqual(rmd1.description, None)
+
+        # update description
+        fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+        with freeze_time(fake_datetime):
+            rm2 = self.store.update_registered_model_description(name, "New description.")
+            rm2d = self.store.get_registered_model(name)
+            self.assertEqual(rm2.name, name)
+            self.assertEqual(rm2.description, "New description.")
+            self.assertEqual(rm2d.name, name)
+            self.assertEqual(rm2d.description, "New description.")
+            self.assertEqual(rm2d.last_updated_time, fake_datetime)
+
+    def test_rename_registered_model(self):
+        name = "test_rename_RM"
+        new_name = "test_rename_RM_new"
+        rm = self.store.create_registered_model(name)
+        self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        self.store.create_model_version(name, "path/to/source", "test", "application_1235")
+        mvd1 = self.store.get_model_version(name, 1)
+        mvd2 = self.store.get_model_version(name, 2)
+        self.assertEqual(rm.name, name)
+        self.assertEqual(mvd1.name, name)
+        self.assertEqual(mvd2.name, name)
+
+        # test renaming registered model also updates its model versions
+        self.store.rename_registered_model(name, new_name)
+        rm = self.store.get_registered_model(new_name)
+        mv1 = self.store.get_model_version(new_name, 1)
+        mv2 = self.store.get_model_version(new_name, 2)
+        self.assertEqual(rm.name, new_name)
+        self.assertEqual(mv1.name, new_name)
+        self.assertEqual(mv2.name, new_name)
+
+        # test accessing the registered model with the original name will fail
+        with self.assertRaises(SubmarineException):
+            self.store.rename_registered_model(name, name)
+
+        # invalid name will fail
+        with self.assertRaises(SubmarineException):
+            self.store.rename_registered_model(name, None)
+        with self.assertRaises(SubmarineException):
+            self.store.rename_registered_model(name, "")
+
+    def test_delete_registered_model(self):
+        name1 = "test_delete_RM"
+        name2 = "test_delete_RM_2"
+        rm_tags = ["rm_tag1", "rm_tag2"]
+        rm1 = self.store.create_registered_model(name1, tags=rm_tags)
+        rm2 = self.store.create_registered_model(name2, tags=rm_tags)
+        mv_tags = ["mv_tag1", "mv_tag2"]
+        rm1mv1 = self.store.create_model_version(
+            rm1.name, "path/to/source", "test", "application_1234", tags=mv_tags
+        )
+        rm2mv1 = self.store.create_model_version(
+            rm2.name, "path/to/source", "test", "application_1234", tags=mv_tags
+        )
+
+        # check store
+        rmd1 = self.store.get_registered_model(rm1.name)
+        self.assertEqual(rmd1.name, name1)
+        self.assertEqual(rmd1.tags, rm_tags)
+        rm1mv1d = self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+        self.assertEqual(rm1mv1d.name, name1)
+        self.assertEqual(rm1mv1d.tags, mv_tags)
+
+        # delete registered model
+        self.store.delete_registered_model(rm1.name)
+
+        # cannot get model
+        with self.assertRaises(SubmarineException):
+            self.store.get_registered_model(rm1.name)
+
+        # cannot delete it again
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model(rm1.name)
+
+        # registered model tag are cascade deleted with the registered model
+        for tag in rm_tags:
+            with self.assertRaises(SubmarineException):
+                self.store.delete_registered_model_tag(rm1.name, tag)
+
+        # model versions are cascade deleted with the registered model
+        with self.assertRaises(SubmarineException):
+            self.store.get_model_version(rm1mv1.name, rm1mv1.version)
+
+        # model tags are cascade deleted with the registered model
+        for tag in mv_tags:
+            with self.assertRaises(SubmarineException):
+                self.store.delete_model_tag(rm1mv1.name, rm1mv1.version, tag)
+
+        # Other registered model and model version is not affected
+        rm2d = self.store.get_registered_model(rm2.name)
+        self.assertEqual(rm2d.name, rm2.name)
+        self.assertEqual(rm2d.tags, rm2.tags)
+        rm2mv1d = self.store.get_model_version(rm2mv1.name, rm2mv1.version)
+        self.assertEqual(rm2mv1d.name, rm2mv1.name)
+        self.assertEqual(rm2mv1d.tags, rm2mv1.tags)
+
+    def _compare_registered_model_names(
+        self, results: List[RegisteredModel], rms: List[RegisteredModel]
+    ):
+        result_names = set([result.name for result in results])
+        rms_names = set([rm.name for rm in rms])
+
+        self.assertEqual(result_names, rms_names)
+
+    def test_list_registered_model(self):
+        rms = [self.store.create_registered_model(f"test_list_RM_{i}") for i in range(10)]
+
+        results = self.store.list_registered_model()
+        self.assertEqual(len(results), 10)
+        self._compare_registered_model_names(results, rms)
+
+    def test_list_registered_model_filter_with_string(self):
+        rms = [
+            self.store.create_registered_model("A"),
+            self.store.create_registered_model("AB"),
+            self.store.create_registered_model("B"),
+            self.store.create_registered_model("ABA"),
+            self.store.create_registered_model("AAA"),
+        ]
+
+        results = self.store.list_registered_model(filter_str="A")
+        self.assertEqual(len(results), 4)
+        self._compare_registered_model_names(rms[:2] + rms[3:], results)
+
+        results = self.store.list_registered_model(filter_str="AB")
+        self.assertEqual(len(results), 2)
+        self._compare_registered_model_names([rms[1], rms[3]], results)
+
+        results = self.store.list_registered_model(filter_str="ABA")
+        self.assertEqual(len(results), 1)
+        self._compare_registered_model_names([rms[3]], results)
+
+        results = self.store.list_registered_model(filter_str="ABC")
+        self.assertEqual(len(results), 0)
+        self.assertEqual(results, [])
+
+    def test_list_registered_model_filter_with_tags(self):
+        tags = ["tag1", "tag2", "tag3"]
+        rms = [
+            self.store.create_registered_model("test1"),
+            self.store.create_registered_model("test2", tags=tags[0:1]),
+            self.store.create_registered_model("test3", tags=tags[1:2]),
+            self.store.create_registered_model("test4", tags=[tags[0], tags[2]]),
+            self.store.create_registered_model("test5", tags=tags),
+        ]
+
+        results = self.store.list_registered_model(filter_tags=tags[0:1])
+        self.assertEqual(len(results), 3)
+        self._compare_registered_model_names(results, [rms[1], rms[3], rms[4]])
+
+        results = self.store.list_registered_model(filter_tags=tags[0:2])
+        self.assertEqual(len(results), 1)
+        self._compare_registered_model_names(results, [rms[-1]])
+
+        # empty result
+        other_tag = ["tag4"]
+        results = self.store.list_registered_model(filter_tags=other_tag)
+        self.assertEqual(len(results), 0)
+        self.assertEqual(results, [])
+
+        # empty result
+        results = self.store.list_registered_model(filter_tags=tags + other_tag)
+        self.assertEqual(len(results), 0)
+        self.assertEqual(results, [])
+
+    def test_list_registered_model_filter_both(self):
+        tags = ["tag1", "tag2", "tag3"]
+        rms = [
+            self.store.create_registered_model("A"),
+            self.store.create_registered_model("AB", tags=[tags[0]]),
+            self.store.create_registered_model("B", tags=[tags[1]]),
+            self.store.create_registered_model("ABA", tags=[tags[0], tags[2]]),
+            self.store.create_registered_model("AAA", tags=tags),
+        ]
+
+        results = self.store.list_registered_model()
+        self.assertEqual(len(results), 5)
+        self._compare_registered_model_names(results, rms)
+
+        results = self.store.list_registered_model(filter_str="A", filter_tags=[tags[0]])
+        self.assertEqual(len(results), 3)
+        self._compare_registered_model_names(results, [rms[1], rms[3], rms[4]])
+
+        results = self.store.list_registered_model(filter_str="AB", filter_tags=[tags[0]])
+        self.assertEqual(len(results), 2)
+        self._compare_registered_model_names(results, [rms[1], rms[3]])
+
+        results = self.store.list_registered_model(filter_str="AAA", filter_tags=tags)
+        self.assertEqual(len(results), 1)
+        self._compare_registered_model_names(results, [rms[-1]])
+
+    @freeze_time("2021-11-11 11:11:11.111000")
+    def test_get_registered_model(self):
+        name = "test_get_RM"
+        tags = ["tag1", "tag2"]
+        fake_datetime = datetime.now()
+        rm = self.store.create_registered_model(name, tags=tags)
+        self.assertEqual(rm.name, name)
+
+        rmd = self.store.get_registered_model(name)
+        self.assertEqual(rmd.name, name)
+        self.assertEqual(rmd.creation_time, fake_datetime)
+        self.assertEqual(rmd.last_updated_time, fake_datetime)
+        self.assertEqual(rmd.description, None)
+        self.assertEqual(rmd.tags, tags)
+
+    def test_add_registered_model_tag(self):
+        name1 = "test_add_RM_tag"
+        name2 = "test_add_RM_tag_2"
+        tags = ["tag1", "tag2"]
+        self.store.create_registered_model(name1, tags=tags)
+        self.store.create_registered_model(name2, tags=tags)
+        new_tag = "new tag"
+        self.store.add_registered_model_tag(name1, new_tag)
+        rmd = self.store.get_registered_model(name1)
+        all_tags = [new_tag] + tags
+        self.assertEqual(rmd.tags, all_tags)
+
+        # test add the same tag
+        same_tag = "tag1"
+        self.store.add_registered_model_tag(name1, same_tag)
+        rmd = self.store.get_registered_model(name1)
+        self.assertEqual(rmd.tags, all_tags)
+
+        # does not affect other models
+        rm2d = self.store.get_registered_model(name2)
+        self.assertEqual(rm2d.tags, tags)
+
+        # cannot set invalid tag
+        with self.assertRaises(SubmarineException):
+            self.store.add_registered_model_tag(name1, None)
+        with self.assertRaises(SubmarineException):
+            self.store.add_registered_model_tag(name1, "")
+
+        # cannot use invalid model name
+        with self.assertRaises(SubmarineException):
+            self.store.add_registered_model_tag(None, new_tag)
+
+        # cannot set tag on deleted registered model
+        self.store.delete_registered_model(name1)
+        with self.assertRaises(SubmarineException):
+            new_tag = "new tag2"
+            self.store.add_registered_model_tag(name1, new_tag)
+
+    def test_delete_registered_model_tag(self):
+        name1 = "test_registered_model"
+        name2 = "test_registered_model_2"
+        tags = ["tag1", "tag2"]
+        self.store.create_registered_model(name1, tags=tags)
+        self.store.create_registered_model(name2, tags=tags)
+        new_tag = "new tag"
+        self.store.add_registered_model_tag(name1, new_tag)
+        self.store.delete_registered_model_tag(name1, new_tag)
+        rmd1 = self.store.get_registered_model(name1)
+        self.assertEqual(rmd1.tags, tags)
+
+        # delete tag that is already deleted
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model_tag(name1, new_tag)
+        rmd1 = self.store.get_registered_model(name1)
+        self.assertEqual(rmd1.tags, tags)
+
+        # does not affect other models
+        rm2d = self.store.get_registered_model(name2)
+        self.assertEqual(rm2d.tags, tags)
+
+        # Cannot delete invalid key
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model_tag(name1, None)
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model_tag(name1, "")
+
+        # Cannot use invalid model name
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model_tag(None, "tag1")
+
+        # Cannot delete tag on deleted (non-existent) registered model
+        self.store.delete_registered_model(name1)
+        with self.assertRaises(SubmarineException):
+            self.store.delete_registered_model_tag(name1, "tag1")
+
+    @freeze_time("2021-11-11 11:11:11.111000")
+    def test_create_model_version(self):
+        name = "test_registered_model"
+        self.store.create_registered_model(name)
+        fake_datetime = datetime.now()
+        mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        self.assertEqual(mv1.name, name)
+        self.assertEqual(mv1.version, 1)
+        self.assertEqual(mv1.creation_time, fake_datetime)
+
+        mvd1 = self.store.get_model_version(mv1.name, mv1.version)
+        self.assertEqual(mvd1.name, name)
+        self.assertEqual(mvd1.user_id, "test")
+        self.assertEqual(mvd1.experiment_id, "application_1234")
+        self.assertEqual(mvd1.current_stage, STAGE_NONE)
+        self.assertEqual(mvd1.creation_time, fake_datetime)
+        self.assertEqual(mvd1.last_updated_time, fake_datetime)
+        self.assertEqual(mvd1.source, "path/to/source")
+        self.assertEqual(mvd1.dataset, None)
+        self.assertEqual(mvd1.description, None)
+
+        # new model versions for same name autoincrement versions
+        mv2 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        mvd2 = self.store.get_model_version(name=mv2.name, version=mv2.version)
+        self.assertEqual(mv2.version, 2)
+        self.assertEqual(mvd2.version, 2)
+
+        # create model version with tags
+        tags = ["tag1", "tag2"]
+        mv3 = self.store.create_model_version(
+            name, "path/to/source", "test", "application_1234", tags=tags
+        )
+        mvd3 = self.store.get_model_version(mv3.name, mv3.version)
+        self.assertEqual(mv3.version, 3)
+        self.assertEqual(mv3.tags, tags)
+        self.assertEqual(mvd3.version, 3)
+        self.assertEqual(mvd3.tags, tags)
+
+        # create model version with description
+        description = "A test description."
+        mv4 = self.store.create_model_version(
+            name, "path/to/source", "test", "application_1234", description=description
+        )
+        mvd4 = self.store.get_model_version(mv4.name, mv4.version)
+        self.assertEqual(mv4.version, 4)
+        self.assertEqual(mv4.description, description)
+        self.assertEqual(mvd4.version, 4)
+        self.assertEqual(mvd4.description, description)
+
+    def test_update_model_version_description(self):
+        name = "test_for_update_MV_description"
+        self.store.create_registered_model(name)
+        mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        mvd1 = self.store.get_model_version(mv1.name, mv1.version)
+        self.assertEqual(mvd1.name, name)
+        self.assertEqual(mvd1.version, 1)
+        self.assertEqual(mvd1.description, None)
+
+        # update description
+        fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+        with freeze_time(fake_datetime):
+            self.store.update_model_version_description(mv1.name, mv1.version, "New description.")
+            mvd2 = self.store.get_model_version(mv1.name, mv1.version)
+            self.assertEqual(mvd2.name, name)
+            self.assertEqual(mvd2.version, 1)
+            self.assertEqual(mvd2.description, "New description.")
+            self.assertEqual(mvd2.last_updated_time, fake_datetime)
+
+    def test_transition_model_version_stage(self):
+        name = "test_transition_MV_stage"
+        self.store.create_registered_model(name)
+        mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        mv2 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+
+        fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
+        with freeze_time(fake_datetime):
+            self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_STAGING)
+            mv1d = self.store.get_model_version(mv1.name, mv1.version)
+            self.assertEqual(mv1d.current_stage, STAGE_STAGING)
+
+            # check last updated time
+            self.assertEqual(mv1d.last_updated_time, fake_datetime)
+            rmd = self.store.get_registered_model(name)
+            self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+        fake_datetime = datetime.strptime("2021-11-11 11:11:22.222000", "%Y-%m-%d %H:%M:%S.%f")
+        with freeze_time(fake_datetime):
+            self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_PRODUCTION)
+            mv1d = self.store.get_model_version(mv1.name, mv1.version)
+            self.assertEqual(mv1d.current_stage, STAGE_PRODUCTION)
+
+            # check last updated time
+            self.assertEqual(mv1d.last_updated_time, fake_datetime)
+            rmd = self.store.get_registered_model(name)
+            self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+        fake_datetime = datetime.strptime("2021-11-11 11:11:22.333000", "%Y-%m-%d %H:%M:%S.%f")
+        with freeze_time(fake_datetime):
+            self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_ARCHIVED)
+            mv1d = self.store.get_model_version(mv1.name, mv1.version)
+            self.assertEqual(mv1d.current_stage, STAGE_ARCHIVED)
+
+            # check last updated time
+            self.assertEqual(mv1d.last_updated_time, fake_datetime)
+            rmd = self.store.get_registered_model(name)
+            self.assertEqual(rmd.last_updated_time, fake_datetime)
+
+        # non-canonical stage names
+        for uncanonical_stage_name in ["STAGING", "staging", "StAgInG"]:
+            self.store.transition_model_version_stage(mv1.name, mv1.version, STAGE_NONE)
+            self.store.transition_model_version_stage(mv1.name, mv1.version, uncanonical_stage_name)
+
+            mv1d = self.store.get_model_version(mv1.name, mv1.version)
+            self.assertEqual(mv1d.current_stage, STAGE_STAGING)
+
+        # Not matching stages
+        with self.assertRaises(SubmarineException):
+            self.store.transition_model_version_stage(mv1.name, mv1.version, None)
+        # Not matching stages
+        with self.assertRaises(SubmarineException):
+            self.store.transition_model_version_stage(mv1.name, mv1.version, "stage")
+
+        # No change for other model
+        mv2d = self.store.get_model_version(mv2.name, mv2.version)
+        self.assertEqual(mv2d.current_stage, STAGE_NONE)
+
+    def test_delete_model_version(self):
+        name = "test_for_delete_MV"
+        tags = ["tag1", "tag2"]
+        self.store.create_registered_model(name)
+        mv = self.store.create_model_version(
+            name, "path/to/source", "test", "application_1234", tags=tags
+        )
+        mvd = self.store.get_model_version(mv.name, mv.version)
+        self.assertEqual(mvd.name, name)
+
+        self.store.delete_model_version(name=mv.name, version=mv.version)
+
+        # model tags are cascade deleted with the model version
+        with self.assertRaises(SubmarineException):
+            self.store.delete_model_tag(mv.name, mv.version, tags[0])
+        with self.assertRaises(SubmarineException):
+            self.store.delete_model_tag(mv.name, mv.version, tags[1])
+
+        # cannot get a deleted model version
+        with self.assertRaises(SubmarineException):
+            self.store.get_model_version(mv.name, mv.version)
+
+        # cannot update description of a deleted model version
+        with self.assertRaises(SubmarineException):
+            self.store.update_model_version_description(mv.name, mv.version, "New description.")
+
+        # cannot delete a non-existing version
+        with self.assertRaises(SubmarineException):
+            self.store.delete_model_version(name=mv.name, version=None)
+
+        # cannot delete a non-existing model name
+        with self.assertRaises(SubmarineException):
+            self.store.delete_model_version(name=None, version=mv.version)
+
+    @freeze_time("2021-11-11 11:11:11.111000")
+    def test_get_model_version(self):
+        name = "test_for_get_MV"
+        tags = ["tag1", "tag2"]
+        self.store.create_registered_model(name)
+        fake_datetime = datetime.now()
+        mv = self.store.create_model_version(
+            name,
+            source="path/to/source",
+            user_id="test",
+            experiment_id="application_1234",
+            tags=tags,
+        )
+        self.assertEqual(mv.creation_time, fake_datetime)
+        self.assertEqual(mv.last_updated_time, fake_datetime)
+        mvd = self.store.get_model_version(mv.name, mv.version)
+        self.assertEqual(mvd.name, name)
+        self.assertEqual(mvd.user_id, "test")
+        self.assertEqual(mvd.experiment_id, "application_1234")
+        self.assertEqual(mvd.current_stage, STAGE_NONE)
+        self.assertEqual(mvd.creation_time, fake_datetime)
+        self.assertEqual(mvd.last_updated_time, fake_datetime)
+        self.assertEqual(mvd.source, "path/to/source")
+        self.assertEqual(mvd.dataset, None)
+        self.assertEqual(mvd.description, None)
+        self.assertEqual(mvd.tags, tags)
+
+    def _compare_model_versions(self, results: List[ModelVersion], mvs: List[ModelVersion]) -> None:
+        result_versions = {result.version for result in results}
+        model_versions = {mv.version for mv in mvs}
+
+        self.assertEqual(result_versions, model_versions)
+
+    @freeze_time("2021-11-11 11:11:11.111000")
+    def test_list_model_version(self):
+        name = "test_list_MV"
+        self.store.create_registered_model(name)
+        tags = ["tag1", "tag2", "tag3"]
+        mvs = [
+            self.store.create_model_version(name, "path/to/source", "test", "application_1234"),
+            self.store.create_model_version(
+                name, "path/to/source", "test", "application_1234", tags=[tags[0]]
+            ),
+            self.store.create_model_version(
+                name, "path/to/source", "test", "application_1234", tags=[tags[1]]
+            ),
+            self.store.create_model_version(
+                name, "path/to/source", "test", "application_1234", tags=[tags[0], tags[2]]
+            ),
+            self.store.create_model_version(
+                name, "path/to/source", "test", "application_1234", tags=tags
+            ),
+        ]
+
+        results = self.store.list_model_version(name)
+        self.assertEqual(len(results), 5)
+        self._compare_model_versions(results, mvs)
+
+        results = self.store.list_model_version(name, filter_tags=tags[0:1])
+        self.assertEqual(len(results), 3)
+        self._compare_model_versions(results, [mvs[1], mvs[3], mvs[4]])
+
+        results = self.store.list_model_version(name, filter_tags=tags[0:2])
+        self.assertEqual(len(results), 1)
+        self._compare_model_versions(results, [mvs[-1]])
+
+        # empty result
+        other_tag = ["tag4"]
+        results = self.store.list_model_version(name, filter_tags=other_tag)
+        self.assertEqual(len(results), 0)
+        self.assertEqual(results, [])
+
+        # empty result
+        results = self.store.list_model_version(name, filter_tags=tags + other_tag)
+        self.assertEqual(len(results), 0)
+        self.assertEqual(results, [])

Review comment:
       This test uses only one registered model name. Adding a second registered model would test `list_model_version` more thoroughly when both a name and tags are provided.
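
       A rough sketch of what that extra coverage could look like. `FakeStore` and `check_listing_is_scoped_by_name` are hypothetical names used only so the snippet runs on its own; the real test would call `self.store` instead, and the simplified `create_model_version` signature here is an assumption, not the one in this PR. The filter is assumed to require every tag to match, as the existing assertions imply.

```python
from typing import Dict, List, Optional, Tuple


class FakeStore:
    """Hypothetical in-memory stand-in for the SQLAlchemy store (illustration only)."""

    def __init__(self) -> None:
        # registered model name -> list of (version, tags)
        self._versions: Dict[str, List[Tuple[int, List[str]]]] = {}

    def create_registered_model(self, name: str) -> None:
        self._versions[name] = []

    def create_model_version(self, name: str, tags: Optional[List[str]] = None) -> int:
        # versions auto-increment per registered model name
        version = len(self._versions[name]) + 1
        self._versions[name].append((version, list(tags or [])))
        return version

    def list_model_version(
        self, name: str, filter_tags: Optional[List[str]] = None
    ) -> List[int]:
        # a version matches only if it carries every requested tag
        wanted = set(filter_tags or [])
        return [v for v, t in self._versions[name] if wanted.issubset(t)]


def check_listing_is_scoped_by_name(store: FakeStore) -> None:
    store.create_registered_model("model_a")
    store.create_registered_model("model_b")
    store.create_model_version("model_a", tags=["tag1"])
    store.create_model_version("model_a", tags=["tag1", "tag2"])
    store.create_model_version("model_b", tags=["tag1"])
    store.create_model_version("model_b")

    # the same tag filter must give different results per model name
    assert store.list_model_version("model_a", filter_tags=["tag1"]) == [1, 2]
    assert store.list_model_version("model_b", filter_tags=["tag1"]) == [1]
    # "tag2" exists only on model_a, so model_b must not pick it up
    assert store.list_model_version("model_b", filter_tags=["tag2"]) == []
```

       With only one registered model, a filter that ignored the name and matched on tags alone would still pass the current assertions; the second model is what catches that.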



