You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by ku...@apache.org on 2021/11/23 05:17:38 UTC

[submarine] branch master updated: SUBMARINE-1077. Model descriptive file

This is an automated email from the ASF dual-hosted git repository.

kuanhsun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new b4de88b  SUBMARINE-1077. Model descriptive file
b4de88b is described below

commit b4de88b08f19285a561f27911c86877fa4e18b13
Author: jeff-901 <b0...@ntu.edu.tw>
AuthorDate: Sat Nov 20 09:24:08 2021 +0800

    SUBMARINE-1077. Model descriptive file
    
    ### What is this PR for?
    Create a description file under each artifact. This file contains the model type, the model input dimensions, and the model output dimensions.
    Also, change the artifact save path in the minio pod for serving.
    ### What type of PR is it?
    Feature
    
    ### Todos
    
    ### What is the Jira issue?
    https://issues.apache.org/jira/browse/SUBMARINE-1077
    
    ### How should this be tested?
    
    ### Screenshots (if appropriate)
    
    ### Questions:
    * Do the license files need updating? No
    * Are there breaking changes for older versions? No
    * Does this need new documentation? Yes
    
    Author: jeff-901 <b0...@ntu.edu.tw>
    
    Signed-off-by: jingch1213717 <ku...@apache.org>
    
    Closes #801 from jeff-901/SUBMARINE-1077 and squashes the following commits:
    
    aff3ba11 [jeff-901] delete data_type and name in description file
    2cd0a4a4 [jeff-901] fix style
    312e237f [jeff-901] update document
    234e989e [jeff-901] fix test
    8b915708 [jeff-901] fix test
    96f5b974 [jeff-901] add model type in db
    eec42890 [jeff-901] add model type
    72d7a8ae [jeff-901] fix style
    2ee1ca8b [jeff-901] description file
---
 dev-support/database/submarine-model.sql           |  1 +
 .../pysubmarine/submarine/artifacts/repository.py  | 15 ++---
 .../entities/model_registry/model_version.py       |  7 ++
 .../pysubmarine/submarine/store/database/models.py |  6 ++
 .../store/model_registry/abstract_store.py         |  1 +
 .../store/model_registry/sqlalchemy_store.py       |  2 +
 .../pysubmarine/submarine/tracking/client.py       | 48 ++++++++++++--
 .../entities/model_registry/test_model_version.py  |  7 ++
 .../store/model_registry/test_sqlalchemy_store.py  | 75 +++++++++++++++-------
 .../database/entities/ModelVersionEntity.java      | 11 ++++
 .../database/mappers/ModelVersionMapper.xml        |  8 ++-
 .../server/model/database/ModelVersionTagTest.java |  1 +
 .../server/model/database/ModelVersionTest.java    | 16 +++--
 .../server/rest/ModelVersionRestApiTest.java       |  4 ++
 website/docs/userDocs/submarine-sdk/tracking.md    | 12 ++++
 15 files changed, 169 insertions(+), 45 deletions(-)

diff --git a/dev-support/database/submarine-model.sql b/dev-support/database/submarine-model.sql
index abad0bd..3ace957 100644
--- a/dev-support/database/submarine-model.sql
+++ b/dev-support/database/submarine-model.sql
@@ -36,6 +36,7 @@ CREATE TABLE `model_version` (
 	`source` VARCHAR(512) NOT NULL COMMENT 'Model saved link',
 	`user_id` VARCHAR(64) NOT NULL COMMENT 'Id of the created user',
 	`experiment_id` VARCHAR(64) NOT NULL,
+	`model_type` VARCHAR(64) NOT NULL COMMENT 'Type of model',
 	`current_stage` VARCHAR(64) COMMENT 'Model stage ex: None, production...',
 	`creation_time` DATETIME(3) COMMENT 'Millisecond precision',
 	`last_updated_time` DATETIME(3) COMMENT 'Millisecond precision',
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 9bee7ae..30622c2 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -27,31 +27,30 @@ class Repository:
             endpoint_url=os.environ.get("MLFLOW_S3_ENDPOINT_URL"),
         )
         self.dest_path = experiment_id
+        self.bucket = "submarine"
 
     def _upload_file(self, local_file: str, bucket: str, key: str) -> None:
         self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key)
 
     def _list_artifact_subfolder(self, artifact_path: str):
         response = self.client.list_objects(
-            Bucket="submarine",
+            Bucket=self.bucket,
             Prefix=os.path.join(self.dest_path, artifact_path) + "/",
             Delimiter="/",
         )
         return response.get("CommonPrefixes")
 
     def log_artifact(self, local_file: str, artifact_path: str) -> None:
-        bucket = "submarine"
         dest_path = self.dest_path
         dest_path = os.path.join(dest_path, artifact_path)
         dest_path = os.path.join(dest_path, os.path.basename(local_file))
         self._upload_file(
             local_file=local_file,
-            bucket=bucket,
+            bucket=self.bucket,
             key=dest_path,
         )
 
     def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
-        bucket = "submarine"
         dest_path = self.dest_path
         list_of_subfolder = self._list_artifact_subfolder(artifact_path)
         if list_of_subfolder is None:
@@ -68,16 +67,16 @@ class Repository:
             for f in filenames:
                 self._upload_file(
                     local_file=os.path.join(root, f),
-                    bucket=bucket,
+                    bucket=self.bucket,
                     key=os.path.join(upload_path, f),
                 )
-        return f"s3://{bucket}/{dest_path}"
+        return f"s3://{self.bucket}/{dest_path}"
 
     def delete_folder(self) -> None:
-        objects_to_delete = self.client.list_objects(Bucket="submarine", Prefix=self.dest_path)
+        objects_to_delete = self.client.list_objects(Bucket=self.bucket, Prefix=self.dest_path)
         if objects_to_delete.get("Contents") is not None:
             delete_keys: dict = {"Objects": []}
             delete_keys["Objects"] = [
                 {"Key": k} for k in [obj["Key"] for obj in objects_to_delete.get("Contents")]
             ]
-            self.client.delete_objects(Bucket="submarine", Delete=delete_keys)
+            self.client.delete_objects(Bucket=self.bucket, Delete=delete_keys)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 0b43e0a..4d2f3fb 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -28,6 +28,7 @@ class ModelVersion(_SubmarineObject):
         source,
         user_id,
         experiment_id,
+        model_type,
         current_stage,
         creation_time,
         last_updated_time,
@@ -40,6 +41,7 @@ class ModelVersion(_SubmarineObject):
         self._source = source
         self._user_id = user_id
         self._experiment_id = experiment_id
+        self._model_type = model_type
         self._current_stage = current_stage
         self._creation_time = creation_time
         self._last_updated_time = last_updated_time
@@ -73,6 +75,11 @@ class ModelVersion(_SubmarineObject):
         return self._experiment_id
 
     @property
+    def model_type(self):
+        """String. Type of model."""
+        return self._model_type
+
+    @property
     def creation_time(self):
         """Datetime object. The creation datetime of this version."""
         return self._creation_time
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/models.py b/submarine-sdk/pysubmarine/submarine/store/database/models.py
index ff55c2a..70bb75e 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/models.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/models.py
@@ -196,6 +196,11 @@ class SqlModelVersion(Base):
     ID to which this version of model belongs to.
     """
 
+    model_type = Column(String(64), nullable=False)
+    """
+    Type of model.
+    """
+
     current_stage = Column(String(64), default=STAGE_NONE)
     """
     Current stage of this version of model: it can be `None`, `Developing`,
@@ -254,6 +259,7 @@ class SqlModelVersion(Base):
             source=self.source,
             user_id=self.user_id,
             experiment_id=self.experiment_id,
+            model_type=self.model_type,
             current_stage=self.current_stage,
             creation_time=self.creation_time,
             last_updated_time=self.last_updated_time,
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
index a5e81ce..a06ed99 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
@@ -129,6 +129,7 @@ class AbstractStore:
         source: str,
         user_id: str,
         experiment_id: str,
+        model_type: str,
         dataset: str = None,
         description: str = None,
         tags: List[str] = None,
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
index 47fb075..3825716 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
@@ -343,6 +343,7 @@ class SqlAlchemyStore(AbstractStore):
         source: str,
         user_id: str,
         experiment_id: str,
+        model_type: str,
         dataset: str = None,
         description: str = None,
         tags: List[str] = None,
@@ -380,6 +381,7 @@ class SqlAlchemyStore(AbstractStore):
                     source=source,
                     user_id=user_id,
                     experiment_id=experiment_id,
+                    model_type=model_type,
                     creation_time=creation_time,
                     last_updated_time=creation_time,
                     dataset=dataset,
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 982dfe8..25ed6b7 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -12,10 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 import re
 import tempfile
 import time
+from typing import Any, Dict
 
 import submarine
 from submarine.artifacts.repository import Repository
@@ -92,7 +94,13 @@ class SubmarineClient(object):
         self.store.log_param(job_id, param)
 
     def save_model(
-        self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+        self,
+        model_type: str,
+        model,
+        artifact_path: str,
+        registered_model_name: str = None,
+        input_dim: list = None,
+        output_dim: list = None,
     ) -> None:
         """
         Save a model into the minio pod.
@@ -101,24 +109,53 @@ class SubmarineClient(object):
         :param artifact_path: Relative path of the artifact in the minio pod.
         :param registered_model_name: If not None, register model into the model registry with
                                       this name. If None, the model only be saved in minio pod.
+        :param input_dim: Save the input dimension of the given model to the description file.
+        :param output_dim: Save the output dimension of the given model to the description file.
         """
         pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
         if not re.fullmatch(pattern, artifact_path):
             raise Exception(
-                "Artifact_path must only contains numbers, characters, hyphen and underscore.      "
-                "        Artifact_path must starts and ends with numbers or characters."
+                "Artifact_path must only contains numbers, characters, hyphen and underscore. "
+                "Artifact_path must starts and ends with numbers or characters."
             )
         with tempfile.TemporaryDirectory() as tempdir:
+            description: Dict[str, Any] = dict()
+            model_save_dir = os.path.join(tempdir, "1")
+            os.mkdir(model_save_dir)
             if model_type == "pytorch":
                 import submarine.models.pytorch
 
-                submarine.models.pytorch.save_model(model, tempdir)
+                if input_dim is None or output_dim is None:
+                    raise Exception(
+                        "Saving pytorch model needs to provide input and output dimension for"
+                        " serving."
+                    )
+                submarine.models.pytorch.save_model(model, model_save_dir)
             elif model_type == "tensorflow":
                 import submarine.models.tensorflow
 
-                submarine.models.tensorflow.save_model(model, tempdir)
+                submarine.models.tensorflow.save_model(model, model_save_dir)
             else:
                 raise Exception("No valid type of model has been matched to {}".format(model_type))
+
+            # Write description file
+            if input_dim is not None:
+                description["input"] = [
+                    {
+                        "dims": input_dim,
+                    }
+                ]
+            if output_dim is not None:
+                description["output"] = [
+                    {
+                        "dims": output_dim,
+                    }
+                ]
+            description["model_type"] = model_type
+            with open(os.path.join(tempdir, "description.json"), "w") as f:
+                json.dump(description, f)
+
+            # Log all files into minio
             source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
 
         # Register model
@@ -132,4 +169,5 @@ class SubmarineClient(object):
                 source=source,
                 user_id="",  # TODO(jeff-901): the user id is needed to be specified.
                 experiment_id=utils.get_job_id(),
+                model_type=model_type,
             )
diff --git a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
index 9cead79..1379c08 100644
--- a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
+++ b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
@@ -26,6 +26,7 @@ class TestModelVersion:
         "source": "path/to/source",
         "user_id": "admin",
         "experiment_id": "experiment_1",
+        "model_type": "tensorflow",
         "current_stage": STAGE_NONE,
         "creation_time": datetime.now(),
         "last_updated_time": datetime.now(),
@@ -42,6 +43,7 @@ class TestModelVersion:
         source,
         user_id,
         experiment_id,
+        model_type,
         current_stage,
         creation_time,
         last_updated_time,
@@ -55,6 +57,7 @@ class TestModelVersion:
         assert model_metadata.source == source
         assert model_metadata.user_id == user_id
         assert model_metadata.experiment_id == experiment_id
+        assert model_metadata.model_type == model_type
         assert model_metadata.current_stage == current_stage
         assert model_metadata.creation_time == creation_time
         assert model_metadata.last_updated_time == last_updated_time
@@ -69,6 +72,7 @@ class TestModelVersion:
             self.default_data["source"],
             self.default_data["user_id"],
             self.default_data["experiment_id"],
+            self.default_data["model_type"],
             self.default_data["current_stage"],
             self.default_data["creation_time"],
             self.default_data["last_updated_time"],
@@ -83,6 +87,7 @@ class TestModelVersion:
             self.default_data["source"],
             self.default_data["user_id"],
             self.default_data["experiment_id"],
+            self.default_data["model_type"],
             self.default_data["current_stage"],
             self.default_data["creation_time"],
             self.default_data["last_updated_time"],
@@ -101,6 +106,7 @@ class TestModelVersion:
             self.default_data["source"],
             self.default_data["user_id"],
             self.default_data["experiment_id"],
+            self.default_data["model_type"],
             self.default_data["current_stage"],
             self.default_data["creation_time"],
             self.default_data["last_updated_time"],
@@ -115,6 +121,7 @@ class TestModelVersion:
             self.default_data["source"],
             self.default_data["user_id"],
             self.default_data["experiment_id"],
+            self.default_data["model_type"],
             self.default_data["current_stage"],
             self.default_data["creation_time"],
             self.default_data["last_updated_time"],
diff --git a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
index 6d57c2e..afe3217 100644
--- a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
@@ -107,8 +107,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
         name = "test_rename_RM"
         new_name = "test_rename_RN_new"
         rm = self.store.create_registered_model(name)
-        self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
-        self.store.create_model_version(name, "path/to/source2", "test", "application_1235")
+        self.store.create_model_version(
+            name, "path/to/source1", "test", "application_1234", "tensorflow"
+        )
+        self.store.create_model_version(
+            name, "path/to/source2", "test", "application_1235", "tensorflow"
+        )
         mv1d = self.store.get_model_version(name, 1)
         mv2d = self.store.get_model_version(name, 2)
         self.assertEqual(rm.name, name)
@@ -142,10 +146,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
         rm2 = self.store.create_registered_model(name2, tags=rm_tags)
         mv_tags = ["mv_tag1", "mv_tag2"]
         rm1mv1 = self.store.create_model_version(
-            rm1.name, "path/to/source1", "test", "application_1234", tags=mv_tags
+            rm1.name, "path/to/source1", "test", "application_1234", "tensorflow", tags=mv_tags
         )
         rm2mv1 = self.store.create_model_version(
-            rm2.name, "path/to/source2", "test", "application_1234", tags=mv_tags
+            rm2.name, "path/to/source2", "test", "application_1234", "tensorflow", tags=mv_tags
         )
 
         # check store
@@ -380,7 +384,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.store.create_registered_model(model_name)
         fake_datetime = datetime.now()
         mv1 = self.store.create_model_version(
-            model_name, "path/to/source1", "test", "application_1234"
+            model_name, "path/to/source1", "test", "application_1234", "tensorflow"
         )
         self.assertEqual(mv1.name, model_name)
         self.assertEqual(mv1.version, 1)
@@ -390,6 +394,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.assertEqual(m1d.name, model_name)
         self.assertEqual(m1d.user_id, "test")
         self.assertEqual(m1d.experiment_id, "application_1234")
+        self.assertEqual(m1d.model_type, "tensorflow")
         self.assertEqual(m1d.current_stage, STAGE_NONE)
         self.assertEqual(m1d.creation_time, fake_datetime)
         self.assertEqual(m1d.last_updated_time, fake_datetime)
@@ -398,7 +403,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
 
         # new model for same registered model autoincrement version
         m2 = self.store.create_model_version(
-            model_name, "path/to/source2", "test", "application_1234"
+            model_name, "path/to/source2", "test", "application_1234", "tensorflow"
         )
         m2d = self.store.get_model_version(m2.name, m2.version)
         self.assertEqual(m2.version, 2)
@@ -407,7 +412,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
         # create model with tags
         tags = ["tag1", "tag2"]
         m3 = self.store.create_model_version(
-            model_name, "path/to/source3", "test", "application_1234", tags=tags
+            model_name, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
         )
         m3d = self.store.get_model_version(m3.name, m3.version)
         self.assertEqual(m3.version, 3)
@@ -418,7 +423,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
         # create model with description
         description = "A test description."
         m4 = self.store.create_model_version(
-            model_name, "path/to/source4", "test", "application_1234", description=description
+            model_name,
+            "path/to/source4",
+            "test",
+            "application_1234",
+            "tensorflow",
+            description=description,
         )
         m4d = self.store.get_model_version(m4.name, m4.version)
         self.assertEqual(m4.version, 4)
@@ -429,7 +439,9 @@ class TestSqlAlchemyStore(unittest.TestCase):
     def test_update_model_version_description(self):
         name = "test_update_MV_description"
         self.store.create_registered_model(name)
-        mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        mv1 = self.store.create_model_version(
+            name, "path/to/source", "test", "application_1234", "tensorflow"
+        )
         m1d = self.store.get_model_version(mv1.name, mv1.version)
         self.assertEqual(m1d.name, name)
         self.assertEqual(m1d.version, 1)
@@ -448,8 +460,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
     def test_transition_model_version_stage(self):
         name = "test_transition_MV_stage"
         self.store.create_registered_model(name)
-        mv1 = self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
-        m2 = self.store.create_model_version(name, "path/to/source2", "test", "application_1234")
+        mv1 = self.store.create_model_version(
+            name, "path/to/source1", "test", "application_1234", "tensorflow"
+        )
+        m2 = self.store.create_model_version(
+            name, "path/to/source2", "test", "application_1234", "tensorflow"
+        )
 
         fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
         with freeze_time(fake_datetime):
@@ -508,7 +524,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
         tags = ["tag1", "tag2"]
         self.store.create_registered_model(name)
         mv = self.store.create_model_version(
-            name, "path/to/source", "test", "application_1234", tags=tags
+            name, "path/to/source", "test", "application_1234", "tensorflow", tags=tags
         )
         mvd = self.store.get_model_version(mv.name, mv.version)
         self.assertEqual(mvd.name, name)
@@ -548,6 +564,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
             source="path/to/source",
             user_id="test",
             experiment_id="application_1234",
+            model_type="tensorflow",
             tags=tags,
         )
         self.assertEqual(mv.creation_time, fake_datetime)
@@ -556,6 +573,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.assertEqual(mvd.name, name)
         self.assertEqual(mvd.user_id, "test")
         self.assertEqual(mvd.experiment_id, "application_1234")
+        self.assertEqual(mvd.model_type, "tensorflow")
         self.assertEqual(mvd.current_stage, STAGE_NONE)
         self.assertEqual(mvd.creation_time, fake_datetime)
         self.assertEqual(mvd.last_updated_time, fake_datetime)
@@ -578,18 +596,25 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.store.create_registered_model(name2)
         tags = ["tag1", "tag2", "tag3"]
         models = [
-            self.store.create_model_version(name1, "path/to/source1", "test", "application_1234"),
             self.store.create_model_version(
-                name1, "path/to/source2", "test", "application_1234", tags=[tags[0]]
+                name1, "path/to/source1", "test", "application_1234", "tensorflow"
+            ),
+            self.store.create_model_version(
+                name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=[tags[0]]
             ),
             self.store.create_model_version(
-                name1, "path/to/source3", "test", "application_1234", tags=[tags[1]]
+                name1, "path/to/source3", "test", "application_1234", "tensorflow", tags=[tags[1]]
             ),
             self.store.create_model_version(
-                name1, "path/to/source4", "test", "application_1234", tags=[tags[0], tags[2]]
+                name1,
+                "path/to/source4",
+                "test",
+                "application_1234",
+                "tensorflow",
+                tags=[tags[0], tags[2]],
             ),
             self.store.create_model_version(
-                name1, "path/to/source5", "test", "application_1234", tags=tags
+                name1, "path/to/source5", "test", "application_1234", "tensorflow", tags=tags
             ),
         ]
 
@@ -623,7 +648,9 @@ class TestSqlAlchemyStore(unittest.TestCase):
     def test_get_model_version_uri(self):
         name = "test_get_model_version_uri"
         self.store.create_registered_model(name)
-        mv = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+        mv = self.store.create_model_version(
+            name, "path/to/source", "test", "application_1234", "tensorflow"
+        )
         uri = self.store.get_model_version_uri(mv.name, mv.version)
         self.assertEqual(uri, "path/to/source")
 
@@ -645,13 +672,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.store.create_registered_model(name1)
         self.store.create_registered_model(name2)
         rm1mv1 = self.store.create_model_version(
-            name1, "path/to/source1", "test", "application_1234", tags=tags
+            name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
         )
         rm1m2 = self.store.create_model_version(
-            name1, "path/to/source2", "test", "application_1234", tags=tags
+            name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
         )
         rm2mv1 = self.store.create_model_version(
-            name2, "path/to/source3", "test", "application_1234", tags=tags
+            name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
         )
         new_tag = "new tag"
         self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
@@ -692,13 +719,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
         self.store.create_registered_model(name1)
         self.store.create_registered_model(name2)
         rm1mv1 = self.store.create_model_version(
-            name1, "path/to/source1", "test", "application_1234", tags=tags
+            name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
         )
         rm1m2 = self.store.create_model_version(
-            name1, "path/to/source2", "test", "application_1234", tags=tags
+            name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
         )
         rm2mv1 = self.store.create_model_version(
-            name2, "path/to/source3", "test", "application_1234", tags=tags
+            name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
         )
         new_tag = "new tag"
         self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
index ade2cb7..a66b6a3 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
@@ -33,6 +33,8 @@ public class ModelVersionEntity {
 
   private String experimentId;
 
+  private String modelType;
+
   private String currentStage;
 
   private Timestamp creationTime;
@@ -85,6 +87,14 @@ public class ModelVersionEntity {
     this.experimentId = experimentId;
   }
 
+  public String getModelType() {
+    return modelType;
+  }
+
+  public void setModelType(String modelType) {
+    this.modelType = modelType;
+  }
+
   public String getCurrentStage() {
     return currentStage;
   }
@@ -142,6 +152,7 @@ public class ModelVersionEntity {
       ", source='" + source + '\'' +
       ", userId='" + userId + '\'' +
       ", experimentId='" + experimentId + '\'' +
+      ", modelType='" + modelType + '\'' +
       ", currentStage='" + currentStage + '\'' +
       ", creationTime='" + creationTime + '\'' +
       ", lastUpdatedTime=" + lastUpdatedTime + '\'' +
diff --git a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
index eda962e..2f3b344 100644
--- a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
+++ b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
@@ -25,6 +25,7 @@
     <result column="source" property="source" />
     <result column="user_id" property="userId" />
     <result column="experiment_id" property="experimentId" />
+    <result column="model_type" property="modelType" />
     <result column="current_stage" property="currentStage" />
     <result column="creation_time" property="creationTime" />
     <result column="last_updated_time" property="lastUpdatedTime" />
@@ -38,6 +39,7 @@
     <result column="source" property="source" />
     <result column="user_id" property="userId" />
     <result column="experiment_id" property="experimentId" />
+    <result column="model_type" property="modelType" />
     <result column="current_stage" property="currentStage" />
     <result column="creation_time" property="creationTime" />
     <result column="last_updated_time" property="lastUpdatedTime" />
@@ -49,7 +51,7 @@
   </resultMap>
 
   <sql id="Base_Column_List">
-    name, version, source, user_id, experiment_id, current_stage, creation_time,
+    name, version, source, user_id, experiment_id, model_type, current_stage, creation_time,
     last_updated_time, dataset, description
   </sql>
 
@@ -75,9 +77,9 @@
   </select>
 
   <insert id="insert" parameterType="org.apache.submarine.server.model.database.entities.ModelVersionEntity">
-    insert into model_version (name, version, source, user_id, experiment_id, current_stage, creation_time, last_updated_time, dataset, description)
+    insert into model_version (name, version, source, user_id, experiment_id, model_type, current_stage, creation_time, last_updated_time, dataset, description)
     values (#{name,jdbcType=VARCHAR}, #{version,jdbcType=INTEGER}, #{source,jdbcType=VARCHAR},
-    #{userId,jdbcType=VARCHAR}, #{experimentId,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
+    #{userId,jdbcType=VARCHAR}, #{experimentId,jdbcType=VARCHAR}, #{modelType,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
     NOW(3), NOW(3), #{dataset,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR});
     <if test="tags != null and !tags.isEmpty()">
       insert INTO model_version_tag (name, version, tag) values
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
index 16f671a..72e5253 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
@@ -57,6 +57,7 @@ public class ModelVersionTagTest {
     modelVersionEntity.setSource("path/to/source");
     modelVersionEntity.setUserId("test");
     modelVersionEntity.setExperimentId("application_1234");
+    modelVersionEntity.setModelType("tensorflow");
     modelVersionService.insert(modelVersionEntity);
 
     ModelVersionTagEntity modelVersionTagEntity = new ModelVersionTagEntity();
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
index 98ba37c..3c089b2 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
@@ -19,17 +19,17 @@
 
 package org.apache.submarine.server.model.database;
 
-import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
-import org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
-import org.apache.submarine.server.model.database.service.ModelVersionService;
-import org.apache.submarine.server.model.database.service.RegisteredModelService;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Test;
-
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
+import org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
+import org.apache.submarine.server.model.database.service.ModelVersionService;
+import org.apache.submarine.server.model.database.service.RegisteredModelService;
+
 public class ModelVersionTest {
   RegisteredModelService registeredModelService = new RegisteredModelService();
   ModelVersionService modelVersionService = new ModelVersionService();
@@ -55,6 +55,7 @@ public class ModelVersionTest {
     modelVersionEntity.setSource("path/to/source");
     modelVersionEntity.setUserId("test");
     modelVersionEntity.setExperimentId("application_1234");
+    modelVersionEntity.setModelType("tensorflow");
     modelVersionEntity.setTags(tags);
     modelVersionService.insert(modelVersionEntity);
 
@@ -67,6 +68,7 @@ public class ModelVersionTest {
     modelVersionEntity2.setSource("path/to/source2");
     modelVersionEntity2.setUserId("test");
     modelVersionEntity2.setExperimentId("application_1234");
+    modelVersionEntity2.setModelType("tensorflow");
     modelVersionEntity2.setTags(tags2);
     modelVersionService.insert(modelVersionEntity2);
 
@@ -93,6 +95,7 @@ public class ModelVersionTest {
     modelVersionEntity.setSource("path/to/source");
     modelVersionEntity.setUserId("test");
     modelVersionEntity.setExperimentId("application_1234");
+    modelVersionEntity.setModelType("tensorflow");
     modelVersionEntity.setTags(tags);
     modelVersionService.insert(modelVersionEntity);
 
@@ -118,6 +121,7 @@ public class ModelVersionTest {
     modelVersionEntity.setSource("path/to/source");
     modelVersionEntity.setUserId("test");
     modelVersionEntity.setExperimentId("application_1234");
+    modelVersionEntity.setModelType("tensorflow");
     modelVersionService.insert(modelVersionEntity);
 
     ModelVersionEntity modelVersionEntitySelected = modelVersionService.select(name, version);
@@ -149,6 +153,7 @@ public class ModelVersionTest {
     modelVersionEntity.setSource("path/to/source");
     modelVersionEntity.setUserId("test");
     modelVersionEntity.setExperimentId("application_1234");
+    modelVersionEntity.setModelType("tensorflow");
     modelVersionService.insert(modelVersionEntity);
 
     modelVersionService.delete(name, version);
@@ -161,6 +166,7 @@ public class ModelVersionTest {
     Assert.assertEquals(expected.getSource(), actual.getSource());
     Assert.assertEquals(expected.getUserId(), actual.getUserId());
     Assert.assertEquals(expected.getExperimentId(), actual.getExperimentId());
+    Assert.assertEquals(expected.getModelType(), actual.getModelType());
     Assert.assertEquals(expected.getCurrentStage(), actual.getCurrentStage());
     Assert.assertNotNull(actual.getCreationTime());
     Assert.assertNotNull(actual.getLastUpdatedTime());
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
index ea682ef..ca2e770 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
@@ -50,6 +50,7 @@ public class ModelVersionRestApiTest {
   private final String modelVersionSource = "s3://submarine/test";
   private final String modelVersionUid = "test123";
   private final String modelVersionExperimentId = "experiment_123";
+  private final String modelVersionModelType = "experiment_123";
   private final String modelVersionTag = "testTag";
 
   private final RegisteredModelService registeredModelService = new RegisteredModelService();
@@ -78,6 +79,7 @@ public class ModelVersionRestApiTest {
     modelVersion1.setSource(modelVersionSource + "1");
     modelVersion1.setUserId(modelVersionUid);
     modelVersion1.setExperimentId(modelVersionExperimentId);
+    modelVersion1.setModelType(modelVersionModelType);
     modelVersionService.insert(modelVersion1);
     modelVersion2.setName(registeredModelName);
     modelVersion2.setDescription(modelVersionDescription + "2");
@@ -85,6 +87,7 @@ public class ModelVersionRestApiTest {
     modelVersion2.setSource(modelVersionSource + "2");
     modelVersion2.setUserId(modelVersionUid);
     modelVersion2.setExperimentId(modelVersionExperimentId);
+    modelVersion2.setModelType(modelVersionModelType);
     modelVersionService.insert(modelVersion2);
   }
 
@@ -174,5 +177,6 @@ public class ModelVersionRestApiTest {
     assertEquals(result.getVersion(), actual.getVersion());
     assertEquals(result.getSource(), actual.getSource());
     assertEquals(result.getExperimentId(), actual.getExperimentId());
+    assertEquals(result.getModelType(), actual.getModelType());
   }
 }
diff --git a/website/docs/userDocs/submarine-sdk/tracking.md b/website/docs/userDocs/submarine-sdk/tracking.md
index f774753..861c394 100644
--- a/website/docs/userDocs/submarine-sdk/tracking.md
+++ b/website/docs/userDocs/submarine-sdk/tracking.md
@@ -63,3 +63,15 @@ log a single key-value metric. The value must always be a number.
   - **key** - Metric name.
   - **value** - Metric value.
   - **step** - A single integer step at which to log the specified Metrics, by default it's 0.
+
+### `submarine.save_model(model_type: str, model, artifact_path: str, registered_model_name: str = None, input_dim: list = None, output_dim: list = None) -> None`
+
+Save a model to the MinIO pod.
+
+> **Parameters**
+  - **model_type** - The type of model. Only `pytorch` and `tensorflow` are supported.
+  - **model** - Model artifact.
+  - **artifact_path** - Model name.
+  - **registered_model_name** - If it is not `None`, the model will be registered into the model registry with this name.
+  - **input_dim** - The input dimension of the model.
+  - **output_dim** - The output dimension of the model.

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org