You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by ku...@apache.org on 2021/11/23 05:17:38 UTC
[submarine] branch master updated: SUBMARINE-1077. Model descriptive file
This is an automated email from the ASF dual-hosted git repository.
kuanhsun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new b4de88b SUBMARINE-1077. Model descriptive file
b4de88b is described below
commit b4de88b08f19285a561f27911c86877fa4e18b13
Author: jeff-901 <b0...@ntu.edu.tw>
AuthorDate: Sat Nov 20 09:24:08 2021 +0800
SUBMARINE-1077. Model descriptive file
### What is this PR for?
Create a description file under each artifact. This file contains the model type, the model input dimensions, and the model output dimensions.
Also, change the artifact save path in the minio pod for serving.
### What type of PR is it?
Feature
### Todos
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1077
### How should this be tested?
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? Yes
Author: jeff-901 <b0...@ntu.edu.tw>
Signed-off-by: jingch1213717 <ku...@apache.org>
Closes #801 from jeff-901/SUBMARINE-1077 and squashes the following commits:
aff3ba11 [jeff-901] delete data_type and name in description file
2cd0a4a4 [jeff-901] fix style
312e237f [jeff-901] uodate document
234e989e [jeff-901] fix test
8b915708 [jeff-901] fix test
96f5b974 [jeff-901] add model type in db
eec42890 [jeff-901] add model type
72d7a8ae [jeff-901] fix style
2ee1ca8b [jeff-901] description file
---
dev-support/database/submarine-model.sql | 1 +
.../pysubmarine/submarine/artifacts/repository.py | 15 ++---
.../entities/model_registry/model_version.py | 7 ++
.../pysubmarine/submarine/store/database/models.py | 6 ++
.../store/model_registry/abstract_store.py | 1 +
.../store/model_registry/sqlalchemy_store.py | 2 +
.../pysubmarine/submarine/tracking/client.py | 48 ++++++++++++--
.../entities/model_registry/test_model_version.py | 7 ++
.../store/model_registry/test_sqlalchemy_store.py | 75 +++++++++++++++-------
.../database/entities/ModelVersionEntity.java | 11 ++++
.../database/mappers/ModelVersionMapper.xml | 8 ++-
.../server/model/database/ModelVersionTagTest.java | 1 +
.../server/model/database/ModelVersionTest.java | 16 +++--
.../server/rest/ModelVersionRestApiTest.java | 4 ++
website/docs/userDocs/submarine-sdk/tracking.md | 12 ++++
15 files changed, 169 insertions(+), 45 deletions(-)
diff --git a/dev-support/database/submarine-model.sql b/dev-support/database/submarine-model.sql
index abad0bd..3ace957 100644
--- a/dev-support/database/submarine-model.sql
+++ b/dev-support/database/submarine-model.sql
@@ -36,6 +36,7 @@ CREATE TABLE `model_version` (
`source` VARCHAR(512) NOT NULL COMMENT 'Model saved link',
`user_id` VARCHAR(64) NOT NULL COMMENT 'Id of the created user',
`experiment_id` VARCHAR(64) NOT NULL,
+ `model_type` VARCHAR(64) NOT NULL COMMENT 'Type of model',
`current_stage` VARCHAR(64) COMMENT 'Model stage ex: None, production...',
`creation_time` DATETIME(3) COMMENT 'Millisecond precision',
`last_updated_time` DATETIME(3) COMMENT 'Millisecond precision',
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 9bee7ae..30622c2 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -27,31 +27,30 @@ class Repository:
endpoint_url=os.environ.get("MLFLOW_S3_ENDPOINT_URL"),
)
self.dest_path = experiment_id
+ self.bucket = "submarine"
def _upload_file(self, local_file: str, bucket: str, key: str) -> None:
self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key)
def _list_artifact_subfolder(self, artifact_path: str):
response = self.client.list_objects(
- Bucket="submarine",
+ Bucket=self.bucket,
Prefix=os.path.join(self.dest_path, artifact_path) + "/",
Delimiter="/",
)
return response.get("CommonPrefixes")
def log_artifact(self, local_file: str, artifact_path: str) -> None:
- bucket = "submarine"
dest_path = self.dest_path
dest_path = os.path.join(dest_path, artifact_path)
dest_path = os.path.join(dest_path, os.path.basename(local_file))
self._upload_file(
local_file=local_file,
- bucket=bucket,
+ bucket=self.bucket,
key=dest_path,
)
def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
- bucket = "submarine"
dest_path = self.dest_path
list_of_subfolder = self._list_artifact_subfolder(artifact_path)
if list_of_subfolder is None:
@@ -68,16 +67,16 @@ class Repository:
for f in filenames:
self._upload_file(
local_file=os.path.join(root, f),
- bucket=bucket,
+ bucket=self.bucket,
key=os.path.join(upload_path, f),
)
- return f"s3://{bucket}/{dest_path}"
+ return f"s3://{self.bucket}/{dest_path}"
def delete_folder(self) -> None:
- objects_to_delete = self.client.list_objects(Bucket="submarine", Prefix=self.dest_path)
+ objects_to_delete = self.client.list_objects(Bucket=self.bucket, Prefix=self.dest_path)
if objects_to_delete.get("Contents") is not None:
delete_keys: dict = {"Objects": []}
delete_keys["Objects"] = [
{"Key": k} for k in [obj["Key"] for obj in objects_to_delete.get("Contents")]
]
- self.client.delete_objects(Bucket="submarine", Delete=delete_keys)
+ self.client.delete_objects(Bucket=self.bucket, Delete=delete_keys)
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 0b43e0a..4d2f3fb 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -28,6 +28,7 @@ class ModelVersion(_SubmarineObject):
source,
user_id,
experiment_id,
+ model_type,
current_stage,
creation_time,
last_updated_time,
@@ -40,6 +41,7 @@ class ModelVersion(_SubmarineObject):
self._source = source
self._user_id = user_id
self._experiment_id = experiment_id
+ self._model_type = model_type
self._current_stage = current_stage
self._creation_time = creation_time
self._last_updated_time = last_updated_time
@@ -73,6 +75,11 @@ class ModelVersion(_SubmarineObject):
return self._experiment_id
@property
+ def model_type(self):
+ """String. Type of model."""
+ return self._model_type
+
+ @property
def creation_time(self):
"""Datetime object. The creation datetime of this version."""
return self._creation_time
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/models.py b/submarine-sdk/pysubmarine/submarine/store/database/models.py
index ff55c2a..70bb75e 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/models.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/models.py
@@ -196,6 +196,11 @@ class SqlModelVersion(Base):
ID to which this version of model belongs to.
"""
+ model_type = Column(String(64), nullable=False)
+ """
+ Type of model.
+ """
+
current_stage = Column(String(64), default=STAGE_NONE)
"""
Current stage of this version of model: it can be `None`, `Developing`,
@@ -254,6 +259,7 @@ class SqlModelVersion(Base):
source=self.source,
user_id=self.user_id,
experiment_id=self.experiment_id,
+ model_type=self.model_type,
current_stage=self.current_stage,
creation_time=self.creation_time,
last_updated_time=self.last_updated_time,
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
index a5e81ce..a06ed99 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
@@ -129,6 +129,7 @@ class AbstractStore:
source: str,
user_id: str,
experiment_id: str,
+ model_type: str,
dataset: str = None,
description: str = None,
tags: List[str] = None,
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
index 47fb075..3825716 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
@@ -343,6 +343,7 @@ class SqlAlchemyStore(AbstractStore):
source: str,
user_id: str,
experiment_id: str,
+ model_type: str,
dataset: str = None,
description: str = None,
tags: List[str] = None,
@@ -380,6 +381,7 @@ class SqlAlchemyStore(AbstractStore):
source=source,
user_id=user_id,
experiment_id=experiment_id,
+ model_type=model_type,
creation_time=creation_time,
last_updated_time=creation_time,
dataset=dataset,
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 982dfe8..25ed6b7 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os
import re
import tempfile
import time
+from typing import Any, Dict
import submarine
from submarine.artifacts.repository import Repository
@@ -92,7 +94,13 @@ class SubmarineClient(object):
self.store.log_param(job_id, param)
def save_model(
- self, model_type: str, model, artifact_path: str, registered_model_name: str = None
+ self,
+ model_type: str,
+ model,
+ artifact_path: str,
+ registered_model_name: str = None,
+ input_dim: list = None,
+ output_dim: list = None,
) -> None:
"""
Save a model into the minio pod.
@@ -101,24 +109,53 @@ class SubmarineClient(object):
:param artifact_path: Relative path of the artifact in the minio pod.
:param registered_model_name: If not None, register model into the model registry with
this name. If None, the model only be saved in minio pod.
+ :param input_dim: Save the input dimension of the given model to the description file.
+ :param output_dim: Save the output dimension of the given model to the description file.
"""
pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
if not re.fullmatch(pattern, artifact_path):
raise Exception(
- "Artifact_path must only contains numbers, characters, hyphen and underscore. "
- " Artifact_path must starts and ends with numbers or characters."
+ "Artifact_path must only contains numbers, characters, hyphen and underscore. "
+ "Artifact_path must starts and ends with numbers or characters."
)
with tempfile.TemporaryDirectory() as tempdir:
+ description: Dict[str, Any] = dict()
+ model_save_dir = os.path.join(tempdir, "1")
+ os.mkdir(model_save_dir)
if model_type == "pytorch":
import submarine.models.pytorch
- submarine.models.pytorch.save_model(model, tempdir)
+ if input_dim is None or output_dim is None:
+ raise Exception(
+ "Saving pytorch model needs to provide input and output dimension for"
+ " serving."
+ )
+ submarine.models.pytorch.save_model(model, model_save_dir)
elif model_type == "tensorflow":
import submarine.models.tensorflow
- submarine.models.tensorflow.save_model(model, tempdir)
+ submarine.models.tensorflow.save_model(model, model_save_dir)
else:
raise Exception("No valid type of model has been matched to {}".format(model_type))
+
+ # Write description file
+ if input_dim is not None:
+ description["input"] = [
+ {
+ "dims": input_dim,
+ }
+ ]
+ if output_dim is not None:
+ description["output"] = [
+ {
+ "dims": output_dim,
+ }
+ ]
+ description["model_type"] = model_type
+ with open(os.path.join(tempdir, "description.json"), "w") as f:
+ json.dump(description, f)
+
+ # Log all files into minio
source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
# Register model
@@ -132,4 +169,5 @@ class SubmarineClient(object):
source=source,
user_id="", # TODO(jeff-901): the user id is needed to be specified.
experiment_id=utils.get_job_id(),
+ model_type=model_type,
)
diff --git a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
index 9cead79..1379c08 100644
--- a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
+++ b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
@@ -26,6 +26,7 @@ class TestModelVersion:
"source": "path/to/source",
"user_id": "admin",
"experiment_id": "experiment_1",
+ "model_type": "tensorflow",
"current_stage": STAGE_NONE,
"creation_time": datetime.now(),
"last_updated_time": datetime.now(),
@@ -42,6 +43,7 @@ class TestModelVersion:
source,
user_id,
experiment_id,
+ model_type,
current_stage,
creation_time,
last_updated_time,
@@ -55,6 +57,7 @@ class TestModelVersion:
assert model_metadata.source == source
assert model_metadata.user_id == user_id
assert model_metadata.experiment_id == experiment_id
+ assert model_metadata.model_type == model_type
assert model_metadata.current_stage == current_stage
assert model_metadata.creation_time == creation_time
assert model_metadata.last_updated_time == last_updated_time
@@ -69,6 +72,7 @@ class TestModelVersion:
self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
+ self.default_data["model_type"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
@@ -83,6 +87,7 @@ class TestModelVersion:
self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
+ self.default_data["model_type"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
@@ -101,6 +106,7 @@ class TestModelVersion:
self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
+ self.default_data["model_type"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
@@ -115,6 +121,7 @@ class TestModelVersion:
self.default_data["source"],
self.default_data["user_id"],
self.default_data["experiment_id"],
+ self.default_data["model_type"],
self.default_data["current_stage"],
self.default_data["creation_time"],
self.default_data["last_updated_time"],
diff --git a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
index 6d57c2e..afe3217 100644
--- a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
@@ -107,8 +107,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
name = "test_rename_RM"
new_name = "test_rename_RN_new"
rm = self.store.create_registered_model(name)
- self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
- self.store.create_model_version(name, "path/to/source2", "test", "application_1235")
+ self.store.create_model_version(
+ name, "path/to/source1", "test", "application_1234", "tensorflow"
+ )
+ self.store.create_model_version(
+ name, "path/to/source2", "test", "application_1235", "tensorflow"
+ )
mv1d = self.store.get_model_version(name, 1)
mv2d = self.store.get_model_version(name, 2)
self.assertEqual(rm.name, name)
@@ -142,10 +146,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
rm2 = self.store.create_registered_model(name2, tags=rm_tags)
mv_tags = ["mv_tag1", "mv_tag2"]
rm1mv1 = self.store.create_model_version(
- rm1.name, "path/to/source1", "test", "application_1234", tags=mv_tags
+ rm1.name, "path/to/source1", "test", "application_1234", "tensorflow", tags=mv_tags
)
rm2mv1 = self.store.create_model_version(
- rm2.name, "path/to/source2", "test", "application_1234", tags=mv_tags
+ rm2.name, "path/to/source2", "test", "application_1234", "tensorflow", tags=mv_tags
)
# check store
@@ -380,7 +384,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(model_name)
fake_datetime = datetime.now()
mv1 = self.store.create_model_version(
- model_name, "path/to/source1", "test", "application_1234"
+ model_name, "path/to/source1", "test", "application_1234", "tensorflow"
)
self.assertEqual(mv1.name, model_name)
self.assertEqual(mv1.version, 1)
@@ -390,6 +394,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.assertEqual(m1d.name, model_name)
self.assertEqual(m1d.user_id, "test")
self.assertEqual(m1d.experiment_id, "application_1234")
+ self.assertEqual(m1d.model_type, "tensorflow")
self.assertEqual(m1d.current_stage, STAGE_NONE)
self.assertEqual(m1d.creation_time, fake_datetime)
self.assertEqual(m1d.last_updated_time, fake_datetime)
@@ -398,7 +403,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
# new model for same registered model autoincrement version
m2 = self.store.create_model_version(
- model_name, "path/to/source2", "test", "application_1234"
+ model_name, "path/to/source2", "test", "application_1234", "tensorflow"
)
m2d = self.store.get_model_version(m2.name, m2.version)
self.assertEqual(m2.version, 2)
@@ -407,7 +412,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
# create model with tags
tags = ["tag1", "tag2"]
m3 = self.store.create_model_version(
- model_name, "path/to/source3", "test", "application_1234", tags=tags
+ model_name, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
)
m3d = self.store.get_model_version(m3.name, m3.version)
self.assertEqual(m3.version, 3)
@@ -418,7 +423,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
# create model with description
description = "A test description."
m4 = self.store.create_model_version(
- model_name, "path/to/source4", "test", "application_1234", description=description
+ model_name,
+ "path/to/source4",
+ "test",
+ "application_1234",
+ "tensorflow",
+ description=description,
)
m4d = self.store.get_model_version(m4.name, m4.version)
self.assertEqual(m4.version, 4)
@@ -429,7 +439,9 @@ class TestSqlAlchemyStore(unittest.TestCase):
def test_update_model_version_description(self):
name = "test_update_MV_description"
self.store.create_registered_model(name)
- mv1 = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+ mv1 = self.store.create_model_version(
+ name, "path/to/source", "test", "application_1234", "tensorflow"
+ )
m1d = self.store.get_model_version(mv1.name, mv1.version)
self.assertEqual(m1d.name, name)
self.assertEqual(m1d.version, 1)
@@ -448,8 +460,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
def test_transition_model_version_stage(self):
name = "test_transition_MV_stage"
self.store.create_registered_model(name)
- mv1 = self.store.create_model_version(name, "path/to/source1", "test", "application_1234")
- m2 = self.store.create_model_version(name, "path/to/source2", "test", "application_1234")
+ mv1 = self.store.create_model_version(
+ name, "path/to/source1", "test", "application_1234", "tensorflow"
+ )
+ m2 = self.store.create_model_version(
+ name, "path/to/source2", "test", "application_1234", "tensorflow"
+ )
fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
with freeze_time(fake_datetime):
@@ -508,7 +524,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
tags = ["tag1", "tag2"]
self.store.create_registered_model(name)
mv = self.store.create_model_version(
- name, "path/to/source", "test", "application_1234", tags=tags
+ name, "path/to/source", "test", "application_1234", "tensorflow", tags=tags
)
mvd = self.store.get_model_version(mv.name, mv.version)
self.assertEqual(mvd.name, name)
@@ -548,6 +564,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
source="path/to/source",
user_id="test",
experiment_id="application_1234",
+ model_type="tensorflow",
tags=tags,
)
self.assertEqual(mv.creation_time, fake_datetime)
@@ -556,6 +573,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.assertEqual(mvd.name, name)
self.assertEqual(mvd.user_id, "test")
self.assertEqual(mvd.experiment_id, "application_1234")
+ self.assertEqual(mvd.model_type, "tensorflow")
self.assertEqual(mvd.current_stage, STAGE_NONE)
self.assertEqual(mvd.creation_time, fake_datetime)
self.assertEqual(mvd.last_updated_time, fake_datetime)
@@ -578,18 +596,25 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name2)
tags = ["tag1", "tag2", "tag3"]
models = [
- self.store.create_model_version(name1, "path/to/source1", "test", "application_1234"),
self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", tags=[tags[0]]
+ name1, "path/to/source1", "test", "application_1234", "tensorflow"
+ ),
+ self.store.create_model_version(
+ name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=[tags[0]]
),
self.store.create_model_version(
- name1, "path/to/source3", "test", "application_1234", tags=[tags[1]]
+ name1, "path/to/source3", "test", "application_1234", "tensorflow", tags=[tags[1]]
),
self.store.create_model_version(
- name1, "path/to/source4", "test", "application_1234", tags=[tags[0], tags[2]]
+ name1,
+ "path/to/source4",
+ "test",
+ "application_1234",
+ "tensorflow",
+ tags=[tags[0], tags[2]],
),
self.store.create_model_version(
- name1, "path/to/source5", "test", "application_1234", tags=tags
+ name1, "path/to/source5", "test", "application_1234", "tensorflow", tags=tags
),
]
@@ -623,7 +648,9 @@ class TestSqlAlchemyStore(unittest.TestCase):
def test_get_model_version_uri(self):
name = "test_get_model_version_uri"
self.store.create_registered_model(name)
- mv = self.store.create_model_version(name, "path/to/source", "test", "application_1234")
+ mv = self.store.create_model_version(
+ name, "path/to/source", "test", "application_1234", "tensorflow"
+ )
uri = self.store.get_model_version_uri(mv.name, mv.version)
self.assertEqual(uri, "path/to/source")
@@ -645,13 +672,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name1)
self.store.create_registered_model(name2)
rm1mv1 = self.store.create_model_version(
- name1, "path/to/source1", "test", "application_1234", tags=tags
+ name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
)
rm1m2 = self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", tags=tags
+ name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
)
rm2mv1 = self.store.create_model_version(
- name2, "path/to/source3", "test", "application_1234", tags=tags
+ name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
)
new_tag = "new tag"
self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
@@ -692,13 +719,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name1)
self.store.create_registered_model(name2)
rm1mv1 = self.store.create_model_version(
- name1, "path/to/source1", "test", "application_1234", tags=tags
+ name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
)
rm1m2 = self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", tags=tags
+ name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
)
rm2mv1 = self.store.create_model_version(
- name2, "path/to/source3", "test", "application_1234", tags=tags
+ name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
)
new_tag = "new tag"
self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
index ade2cb7..a66b6a3 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
@@ -33,6 +33,8 @@ public class ModelVersionEntity {
private String experimentId;
+ private String modelType;
+
private String currentStage;
private Timestamp creationTime;
@@ -85,6 +87,14 @@ public class ModelVersionEntity {
this.experimentId = experimentId;
}
+ public String getModelType() {
+ return modelType;
+ }
+
+ public void setModelType(String modelType) {
+ this.modelType = modelType;
+ }
+
public String getCurrentStage() {
return currentStage;
}
@@ -142,6 +152,7 @@ public class ModelVersionEntity {
", source='" + source + '\'' +
", userId='" + userId + '\'' +
", experimentId='" + experimentId + '\'' +
+ ", modelType='" + modelType + '\'' +
", currentStage='" + currentStage + '\'' +
", creationTime='" + creationTime + '\'' +
", lastUpdatedTime=" + lastUpdatedTime + '\'' +
diff --git a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
index eda962e..2f3b344 100644
--- a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
+++ b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
@@ -25,6 +25,7 @@
<result column="source" property="source" />
<result column="user_id" property="userId" />
<result column="experiment_id" property="experimentId" />
+ <result column="model_type" property="modelType" />
<result column="current_stage" property="currentStage" />
<result column="creation_time" property="creationTime" />
<result column="last_updated_time" property="lastUpdatedTime" />
@@ -38,6 +39,7 @@
<result column="source" property="source" />
<result column="user_id" property="userId" />
<result column="experiment_id" property="experimentId" />
+ <result column="model_type" property="modelType" />
<result column="current_stage" property="currentStage" />
<result column="creation_time" property="creationTime" />
<result column="last_updated_time" property="lastUpdatedTime" />
@@ -49,7 +51,7 @@
</resultMap>
<sql id="Base_Column_List">
- name, version, source, user_id, experiment_id, current_stage, creation_time,
+ name, version, source, user_id, experiment_id, model_type, current_stage, creation_time,
last_updated_time, dataset, description
</sql>
@@ -75,9 +77,9 @@
</select>
<insert id="insert" parameterType="org.apache.submarine.server.model.database.entities.ModelVersionEntity">
- insert into model_version (name, version, source, user_id, experiment_id, current_stage, creation_time, last_updated_time, dataset, description)
+ insert into model_version (name, version, source, user_id, experiment_id, model_type, current_stage, creation_time, last_updated_time, dataset, description)
values (#{name,jdbcType=VARCHAR}, #{version,jdbcType=INTEGER}, #{source,jdbcType=VARCHAR},
- #{userId,jdbcType=VARCHAR}, #{experimentId,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
+ #{userId,jdbcType=VARCHAR}, #{experimentId,jdbcType=VARCHAR}, #{modelType,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
NOW(3), NOW(3), #{dataset,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR});
<if test="tags != null and !tags.isEmpty()">
insert INTO model_version_tag (name, version, tag) values
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
index 16f671a..72e5253 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
@@ -57,6 +57,7 @@ public class ModelVersionTagTest {
modelVersionEntity.setSource("path/to/source");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
+ modelVersionEntity.setModelType("tensorflow");
modelVersionService.insert(modelVersionEntity);
ModelVersionTagEntity modelVersionTagEntity = new ModelVersionTagEntity();
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
index 98ba37c..3c089b2 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
@@ -19,17 +19,17 @@
package org.apache.submarine.server.model.database;
-import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
-import org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
-import org.apache.submarine.server.model.database.service.ModelVersionService;
-import org.apache.submarine.server.model.database.service.RegisteredModelService;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
-
import java.util.ArrayList;
import java.util.List;
+import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
+import org.apache.submarine.server.model.database.entities.RegisteredModelEntity;
+import org.apache.submarine.server.model.database.service.ModelVersionService;
+import org.apache.submarine.server.model.database.service.RegisteredModelService;
+
public class ModelVersionTest {
RegisteredModelService registeredModelService = new RegisteredModelService();
ModelVersionService modelVersionService = new ModelVersionService();
@@ -55,6 +55,7 @@ public class ModelVersionTest {
modelVersionEntity.setSource("path/to/source");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
+ modelVersionEntity.setModelType("tensorflow");
modelVersionEntity.setTags(tags);
modelVersionService.insert(modelVersionEntity);
@@ -67,6 +68,7 @@ public class ModelVersionTest {
modelVersionEntity2.setSource("path/to/source2");
modelVersionEntity2.setUserId("test");
modelVersionEntity2.setExperimentId("application_1234");
+ modelVersionEntity2.setModelType("tensorflow");
modelVersionEntity2.setTags(tags2);
modelVersionService.insert(modelVersionEntity2);
@@ -93,6 +95,7 @@ public class ModelVersionTest {
modelVersionEntity.setSource("path/to/source");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
+ modelVersionEntity.setModelType("tensorflow");
modelVersionEntity.setTags(tags);
modelVersionService.insert(modelVersionEntity);
@@ -118,6 +121,7 @@ public class ModelVersionTest {
modelVersionEntity.setSource("path/to/source");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
+ modelVersionEntity.setModelType("tensorflow");
modelVersionService.insert(modelVersionEntity);
ModelVersionEntity modelVersionEntitySelected = modelVersionService.select(name, version);
@@ -149,6 +153,7 @@ public class ModelVersionTest {
modelVersionEntity.setSource("path/to/source");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
+ modelVersionEntity.setModelType("tensorflow");
modelVersionService.insert(modelVersionEntity);
modelVersionService.delete(name, version);
@@ -161,6 +166,7 @@ public class ModelVersionTest {
Assert.assertEquals(expected.getSource(), actual.getSource());
Assert.assertEquals(expected.getUserId(), actual.getUserId());
Assert.assertEquals(expected.getExperimentId(), actual.getExperimentId());
+ Assert.assertEquals(expected.getModelType(), actual.getModelType());
Assert.assertEquals(expected.getCurrentStage(), actual.getCurrentStage());
Assert.assertNotNull(actual.getCreationTime());
Assert.assertNotNull(actual.getLastUpdatedTime());
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
index ea682ef..ca2e770 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
@@ -50,6 +50,7 @@ public class ModelVersionRestApiTest {
private final String modelVersionSource = "s3://submarine/test";
private final String modelVersionUid = "test123";
private final String modelVersionExperimentId = "experiment_123";
+ private final String modelVersionModelType = "experiment_123";
private final String modelVersionTag = "testTag";
private final RegisteredModelService registeredModelService = new RegisteredModelService();
@@ -78,6 +79,7 @@ public class ModelVersionRestApiTest {
modelVersion1.setSource(modelVersionSource + "1");
modelVersion1.setUserId(modelVersionUid);
modelVersion1.setExperimentId(modelVersionExperimentId);
+ modelVersion1.setModelType(modelVersionModelType);
modelVersionService.insert(modelVersion1);
modelVersion2.setName(registeredModelName);
modelVersion2.setDescription(modelVersionDescription + "2");
@@ -85,6 +87,7 @@ public class ModelVersionRestApiTest {
modelVersion2.setSource(modelVersionSource + "2");
modelVersion2.setUserId(modelVersionUid);
modelVersion2.setExperimentId(modelVersionExperimentId);
+ modelVersion2.setModelType(modelVersionModelType);
modelVersionService.insert(modelVersion2);
}
@@ -174,5 +177,6 @@ public class ModelVersionRestApiTest {
assertEquals(result.getVersion(), actual.getVersion());
assertEquals(result.getSource(), actual.getSource());
assertEquals(result.getExperimentId(), actual.getExperimentId());
+ assertEquals(result.getModelType(), actual.getModelType());
}
}
diff --git a/website/docs/userDocs/submarine-sdk/tracking.md b/website/docs/userDocs/submarine-sdk/tracking.md
index f774753..861c394 100644
--- a/website/docs/userDocs/submarine-sdk/tracking.md
+++ b/website/docs/userDocs/submarine-sdk/tracking.md
@@ -63,3 +63,15 @@ log a single key-value metric. The value must always be a number.
- **key** - Metric name.
- **value** - Metric value.
- **step** - A single integer step at which to log the specified Metrics, by default it's 0.
+
+### `submarine.save_model(model_type: str, model, artifact_path: str, registered_model_name: str = None, input_dim: list = None, output_dim: list = None) -> None`
+
+ Save a model into the minio pod.
+
+> **Parameters**
+ - **model_type** - The type of model. Only `pytorch` and `tensorflow` are supported.
+ - **model** - Model artifact.
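+ - **artifact_path** - Relative path of the artifact in the minio pod. It may contain only letters, digits, hyphens, and underscores, and must start and end with a letter or digit.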
+ - **registered_model_name** - If it is not `None`, the model will be registered into the model registry with this name.
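+ - **input_dim** - The input dimension of the model, saved to the description file. Required when `model_type` is `pytorch`.
+ - **output_dim** - The output dimension of the model, saved to the description file. Required when `model_type` is `pytorch`.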
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org