You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@submarine.apache.org by GitBox <gi...@apache.org> on 2021/09/18 17:12:37 UTC

[GitHub] [submarine] pingsutw commented on a change in pull request #752: SUBMARINE-1023. Submarine SDK sqlalchemy store (model registry)

pingsutw commented on a change in pull request #752:
URL: https://github.com/apache/submarine/pull/752#discussion_r711610888



##########
File path: submarine-sdk/pysubmarine/submarine/utils/db_utils.py
##########
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import env
+
+_DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
+
+
+_db_uri = None
+
+
+def is_db_uri_set():
+    """Returns True if the DB URI has been set, False otherwise."""
+    if _db_uri or env.get_env(_DB_URI_ENV_VAR):
+        return True
+    return False
+
+
+def set_db_uri(uri):
+    """
+    Set the DB URI. This does not affect the currently active run (if one exists),
+    but takes effect for successive runs.
+    """
+    global _tracking_uri
+    _tracking_uri = uri
+
+
+def get_db_uri():
+    """
+    Get the current DB URI.
+    :return: The tracking URI.

Review comment:
       ```suggestion
       :return: The DB URI.
   ```

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_description,
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the dialects ``mysql``.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(
+        self, name: str, description: str = None, tags: List[str] = None
+    ) -> RegisteredModel:
+        """
+        Create a new registered model in backend store.
+        :param name: Name of the new model. This is expected to be unique in the backend store.
+        :param description: Description of the model.
+        :param tags: A list of string associated with this registered model.
+        :return: A single object of :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 created in the backend.
+        """
+        validate_model_name(name)
+        validate_tags(tags)
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            try:
+                creation_time = datetime.now()
+                registered_model = SqlRegisteredModel(
+                    name=name,
+                    creation_time=creation_time,
+                    last_updated_time=creation_time,
+                    description=description,
+                    registered_model_tags=[SqlRegisteredModelTag(tag=tag) for tag in tags or []],
+                )
+                self._save_to_db(session, registered_model)
+                session.flush()
+                return registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={name}) already exists.\nError: {str(e)}"
+                )
+
+    @classmethod
+    def _get_registered_model(
+        cls, session: Session, name: str, eager: bool = False
+    ) -> SqlRegisteredModel:
+        """
+        :param eager: If ``True``, eagerly loads the registered model's tags.
+                      If ``False``, these attributes are not eagerly loaded and
+                      will be loaded when their corresponding object properties
+                      are accessed from the resulting ``SqlRegisteredModel`` object.
+        """
+        validate_model_name(name)
+        query_options = cls._get_eager_registered_model_query_options() if eager else []
+        models: List[SqlRegisteredModel] = (
+            session.query(SqlRegisteredModel)
+            .options(*query_options)
+            .filter(SqlRegisteredModel.name == name)
+            .all()
+        )
+
+        if len(models) == 0:
+            raise SubmarineException(f"Registered Model with name={name} not found")
+        elif len(models) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model with name={name}.\nFound {len(models)}"
+            )
+        else:
+            return models[0]
+
+    def update_registered_model_description(self, name: str, description: str) -> RegisteredModel:
+        """
+        Update description of the registered model.
+        :param name: Registered model name.
+        :param description: New description.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_description(description)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            sql_registered_model.description = description
+            sql_registered_model.last_updated_time = datetime.now()
+            self._save_to_db(session, sql_registered_model)
+            session.flush()
+            return sql_registered_model.to_submarine_entity()
+
+    def rename_registered_model(self, name: str, new_name: str) -> RegisteredModel:
+        """
+        Rename the registered model.
+        :param name: Registered model name.
+        :param new_name: New proposed name.
+        :return: A single updated :py:class:`submarine.entities.model_registry.RegisteredModel`
+                 object.
+        """
+        validate_model_name(new_name)
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            try:
+                update_time = datetime.now()
+                sql_registered_model.name = new_name
+                sql_registered_model.last_updated_time = update_time
+                for sql_model_version in sql_registered_model.model_versions:
+                    sql_model_version.name = new_name
+                    sql_model_version.last_updated_time = update_time
+                self._save_to_db(
+                    session, [sql_registered_model] + sql_registered_model.model_versions
+                )
+                session.flush()
+                return sql_registered_model.to_submarine_entity()
+            except sqlalchemy.exc.IntegrityError as e:
+                raise SubmarineException(
+                    f"Registered Model (name={name}) already exists. Error: {str(e)}"
+                )
+
+    def delete_registered_model(self, name: str) -> None:
+        """
+        Delete the registered model.
+        :param name: Registered model name.
+        :return: None
+        """
+        with self.ManagedSessionMaker() as session:
+            sql_registered_model = self._get_registered_model(session, name)
+            session.delete(sql_registered_model)
+
+    def list_registered_model(
+        self, filter_str: str = None, filter_tags: List[str] = None
+    ) -> List[RegisteredModel]:
+        """
+        List of all registered models.
+        :param filter_string: Filter query string, defaults to searching all registered models.
+        :param filter_tags: Filter tags, defaults not to filter any tags.
+        :return: A List of :py:class:`submarine.entities.model_registry.RegisteredModel` objects
+                that satisfy the search expressions.
+        """
+        conditions = []
+        if filter_tags is not None:
+            conditions = [
+                SqlRegisteredModel.registered_model_tags.any(
+                    SqlRegisteredModelTag.tag.contains(tag)
+                )
+                for tag in filter_tags
+            ]
+        if filter_str is not None:
+            conditions.append(SqlRegisteredModel.name.startswith(filter_str))
+        with self.ManagedSessionMaker() as session:
+            registered_models = session.query(SqlRegisteredModel).filter(*conditions).all()
+            return [
+                registered_model.to_submarine_entity() for registered_model in registered_models
+            ]
+
+    def get_registered_model(self, name: str) -> RegisteredModel:
+        """
+        Get registered model instance by name.
+        :param name: Registered model name.
+        :return: A single :py:class:`submarine.entities.model_registry.RegisteredModel` object.
+        """
+        with self.ManagedSessionMaker() as session:
+            return self._get_registered_model(session, name, True).to_submarine_entity()
+
+    @classmethod
+    def _get_registered_model_tag(
+        cls, session: Session, name: str, tag: str
+    ) -> SqlRegisteredModelTag:
+        tags = (
+            session.query(SqlRegisteredModelTag)
+            .filter(SqlRegisteredModelTag.name == name, SqlRegisteredModelTag.tag == tag)
+            .all()
+        )
+        if len(tags) == 0:
+            raise SubmarineException(f"Registered model tag with name={name}, tag={tag} not found")
+        elif len(tags) > 1:
+            raise SubmarineException(
+                f"Expected only 1 registered model version tag with name={name}, tag={tag}. Found"
+                f" {len(tags)}."
+            )
+        else:
+            return tags[0]
+
+    def add_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Add a tag for the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            session.merge(SqlRegisteredModelTag(name=name, tag=tag))
+
+    def delete_registered_model_tag(self, name: str, tag: str) -> None:
+        """
+        Delete a tag associated with the registered model.
+        :param name: Registered model name.
+        :param tag: String of tag value.
+        :return: None
+        """
+        validate_model_name(name)
+        validate_tag(tag)
+        with self.ManagedSessionMaker() as session:
+            # check if registered model exists
+            self._get_registered_model(session, name)
+            existing_tag = self._get_registered_model_tag(session, name, tag)
+            session.delete(existing_tag)
+
+    def create_model_version(
+        self,
+        name: str,
+        source: str,
+        user_id: str,
+        experiment_id: str,
+        dataset: str = None,
+        description: str = None,
+        tags: List[str] = None,

Review comment:
       What's the difference between the tags in `model version` and `registered model`?

##########
File path: submarine-sdk/pysubmarine/submarine/utils/db_utils.py
##########
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from submarine.store import DEFAULT_SUBMARINE_JDBC_URL
+from submarine.utils import env
+
+_DB_URI_ENV_VAR = "SUBMARINE_DB_URI"
+
+
+_db_uri = None
+
+
+def is_db_uri_set():
+    """Returns True if the DB URI has been set, False otherwise."""
+    if _db_uri or env.get_env(_DB_URI_ENV_VAR):
+        return True
+    return False
+
+
+def set_db_uri(uri):
+    """
+    Set the DB URI. This does not affect the currently active run (if one exists),
+    but takes effect for successive runs.
+    """
+    global _tracking_uri
+    _tracking_uri = uri

Review comment:
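       Should this be `_db_uri`? The function declares and assigns `_tracking_uri`, so the module-level `_db_uri` that `is_db_uri_set()` checks is not updated.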

##########
File path: submarine-sdk/pysubmarine/github-actions/test-requirements.txt
##########
@@ -23,11 +23,12 @@ pytest==3.2.1
 pytest-cov==2.6.0
 pytest-localserver==0.5.0
 pylint==2.5.2
-sqlalchemy==1.3.0
+sqlalchemy >= 1.4.0

Review comment:
       We could remove it, right? We already have it in setup.py.

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
##########
@@ -0,0 +1,234 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta, abstractmethod
+from typing import List
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+
+
+class AbstractStore:
+    """
+    Abstract class for Backend model registry
+    This class defines the API interface for frontends to connect with various types of backends.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self) -> None:
+        """
+        Empty constructor for now. This is deliberately not marked as abstract, else every
+        derived class would be forced to create one.
+        """
+        pass
+
+    @abstractmethod
+    def create_registered_model(

Review comment:
       Do we need to put these functions here? https://github.com/apache/submarine/blob/master/submarine-sdk/pysubmarine/submarine/store/abstract_store.py

##########
File path: submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
##########
@@ -0,0 +1,572 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from contextlib import contextmanager
+from datetime import datetime
+from typing import List, Union
+
+import sqlalchemy
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm.session import Session, sessionmaker
+from sqlalchemy.orm.strategy_options import _UnboundLoad
+
+from submarine.entities.model_registry import ModelVersion, RegisteredModel
+from submarine.entities.model_registry.model_version_stages import (
+    STAGE_DELETED_INTERNAL,
+    get_canonical_stage,
+)
+from submarine.exceptions import SubmarineException
+from submarine.store.database.models import (
+    Base,
+    SqlModelTag,
+    SqlModelVersion,
+    SqlRegisteredModel,
+    SqlRegisteredModelTag,
+)
+from submarine.store.model_registry.abstract_store import AbstractStore
+from submarine.utils import extract_db_type_from_uri
+from submarine.utils.validation import (
+    validate_description,
+    validate_model_name,
+    validate_model_version,
+    validate_tag,
+    validate_tags,
+)
+
+_logger = logging.getLogger(__name__)
+
+
+class SqlAlchemyStore(AbstractStore):
+    def __init__(self, db_uri: str) -> None:
+        """
+        Create a database backed store.
+        :param db_uri: The SQLAlchemy database URI string to connect to the database. See
+                       the `SQLAlchemy docs
+                       <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_
+                       for format specifications. Submarine supports the dialects ``mysql``.
+        """
+        super(SqlAlchemyStore, self).__init__()
+
+        self.db_uri = db_uri
+        self.db_type = extract_db_type_from_uri(db_uri)
+        self.engine = sqlalchemy.create_engine(db_uri, pool_pre_ping=True)
+        insp = sqlalchemy.inspect(self.engine)
+
+        # Verify that all model registry tables exist.
+        expected_tables = {
+            SqlModelVersion.__tablename__,
+            SqlModelTag.__tablename__,
+            SqlRegisteredModel.__tablename__,
+            SqlRegisteredModelTag.__tablename__,
+        }
+        if len(expected_tables & set(insp.get_table_names())) == 0:
+            SqlAlchemyStore._initialize_tables(self.engine)
+        Base.metadata.bind = self.engine
+        SessionMaker = sqlalchemy.orm.sessionmaker(bind=self.engine)
+        self.ManagedSessionMaker = self._get_managed_session_maker(SessionMaker)
+
+    @staticmethod
+    def _initialize_tables(engine: Engine):
+        _logger.info("Creating initial Submarine database tables...")
+        Base.metadata.create_all(engine)
+
+    @staticmethod
+    def _get_managed_session_maker(SessionMaker: sessionmaker):
+        """
+        Creates a factory for producing exception-safe SQLAlchemy sessions that are made available
+        using a context manager. Any session produced by this factory is automatically committed
+        if no exceptions are encountered within its associated context. If an exception is
+        encountered, the session is rolled back. Finally, any session produced by this factory is
+        automatically closed when the session's associated context is exited.
+        """
+
+        @contextmanager
+        def make_managed_session():
+            """Provide a transactional scope around a series of operations."""
+            session: Session = SessionMaker()
+            try:
+                yield session
+                session.commit()
+            except SubmarineException:
+                session.rollback()
+                raise
+            except Exception as e:
+                session.rollback()
+                raise SubmarineException(e)
+            finally:
+                session.close()
+
+        return make_managed_session
+
+    @staticmethod
+    def _get_eager_registered_model_query_options() -> List[_UnboundLoad]:
+        """
+        :return A list of SQLAlchemy query options that can be used to eagerly
+                load the following registered model attributes
+                when fetching a registered model: ``registered_model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlRegisteredModel.registered_model_tags)]
+
+    @staticmethod
+    def _get_eager_model_version_query_options():
+        """
+        :return: A list of SQLAlchemy query options that can be used to eagerly
+                load the following model version attributes
+                when fetching a model version: ``model_tags``.
+        """
+        return [sqlalchemy.orm.subqueryload(SqlModelVersion.model_tags)]
+
+    def _save_to_db(self, session: Session, objs: Union[list, object]) -> None:
+        """
+        Store in db
+        """
+        if type(objs) is list:
+            session.add_all(objs)
+        else:
+            # single object
+            session.add(objs)
+
+    def create_registered_model(

Review comment:
       ditto.
   https://github.com/apache/submarine/blob/master/submarine-sdk/pysubmarine/submarine/store/sqlalchemy_store.py

##########
File path: submarine-sdk/pysubmarine/submarine/utils/validation.py
##########
@@ -116,6 +117,45 @@ def validate_param(key, value):
     _validate_length_limit("Param value", MAX_PARAM_VAL_LENGTH, str(value))
 
 
+def validate_tags(tags: Optional[List[str]]) -> None:
+    if tags is not None and not isinstance(tags, list):
+        raise SubmarineException("parameter tags must be list or None.")
+    for tag in tags or []:
+        validate_tag(tag)
+
+
+def validate_tag(tag: str) -> None:
+    """Check that `tag` is a valid tag value and raise an exception if it isn't."""
+    # Reuse param & metric check.
+    if tag is None or tag == "":
+        raise SubmarineException("Tag cannot be empty.")
+    if not _VALID_PARAM_AND_METRIC_NAMES.match(tag):
+        raise SubmarineException("Invalid tag name: '%s'. %s" % (tag, _BAD_CHARACTERS_MESSAGE))

Review comment:
       ```suggestion
           raise SubmarineException(f"Invalid tag name: {tag}. {_BAD_CHARACTERS_MESSAGE}")
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@submarine.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org