You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2022/02/15 04:33:13 UTC
[submarine] branch master updated: SUBMARINE-1182. Use the unique source for model management and use it for later serving
This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new 1dc6757 SUBMARINE-1182. Use the unique source for model management and use it for later serving
1dc6757 is described below
commit 1dc675785b6296d034a517127aee2a79fa57dda5
Author: KUAN-HSUN-LI <b0...@ntu.edu.tw>
AuthorDate: Tue Feb 15 10:27:29 2022 +0800
SUBMARINE-1182. Use the unique source for model management and use it for later serving
### What is this PR for?
1. Give each model a unique source (ID) so that it can be used for later serving.
2. Provide a REST API for creating model versions.
3. Fix bugs
### What type of PR is it?
[Bug Fix | Feature]
### Todos
* [x] - Provide unique model source.
* [x] - Serve a model created by an experiment.
* [ ] - Update documentation.
* [ ] - Provide tests.
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1182
### How should this be tested?
### Screenshots (if appropriate)
1. PyTorch example
https://user-images.githubusercontent.com/38066413/151348577-6a6c14a3-3404-43e7-92b1-f0a7a63bb6fe.mp4
2. TensorFlow example
https://user-images.githubusercontent.com/38066413/151350155-669d8442-63f3-4349-90f9-66c2f6ae6c35.mp4
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? Yes
Author: KUAN-HSUN-LI <b0...@ntu.edu.tw>
Signed-off-by: Kevin <pi...@apache.org>
Closes #876 from KUAN-HSUN-LI/SUBMARINE-1182 and squashes the following commits:
b2c2b949 [KUAN-HSUN-LI] fix
598a1f2a [KUAN-HSUN-LI] SUBMARINE-1182. Complete the manage source and serving source
04f23f9b [KUAN-HSUN-LI] SUBMARINE-1182. Complete the manage source and serving source
2abb8516 [KUAN-HSUN-LI] SUBMARINE-1182. Complete the manage source and serving source
101dd899 [KUAN-HSUN-LI] SUBMARINE-1182. add id in spec
5247206d [KUAN-HSUN-LI] SUBMARINE-1182. Change log artifact path in Python SDK
6be64872 [KUAN-HSUN-LI] SUBMARINE-1182. Remove model version source
---
dev-support/database/submarine-model.sql | 8 +-
.../MirroredStrategy/mnist_keras_distributed.py | 4 +-
dev-support/examples/nn-pytorch/model.py | 5 +-
.../pysubmarine/submarine/artifacts/repository.py | 24 ++---
.../pysubmarine/submarine/client/__init__.py | 3 +
.../pytorch.py => client/utils/__init__.py} | 19 ++--
.../entities/model_registry/model_version.py | 35 +++----
.../pysubmarine/submarine/models/pytorch.py | 2 +-
.../pysubmarine/submarine/store/database/models.py | 29 +++---
.../store/model_registry/abstract_store.py | 4 +-
.../store/model_registry/sqlalchemy_store.py | 10 +-
.../pysubmarine/submarine/tracking/client.py | 103 +++++++++++++++------
.../pysubmarine/submarine/tracking/fluent.py | 9 +-
.../pysubmarine/submarine/tracking/utils.py | 4 +
.../entities/model_registry/test_model_version.py | 14 +--
.../store/model_registry/test_sqlalchemy_store.py | 77 +++++++--------
.../pysubmarine/tests/tracking/test_tracking.py | 17 ++--
.../submarine/server/api/model/ServeSpec.java | 9 ++
.../submarine/server/model/ModelManager.java | 39 ++++++--
.../database/entities/ModelVersionEntity.java | 14 +--
.../submarine/server/rest/ExperimentRestApi.java | 2 +-
.../submarine/server/rest/ModelVersionRestApi.java | 65 ++++++++++++-
.../submarine/server/rest/RestConstants.java | 2 +-
.../apache/submarine/server/rest/ServeRestApi.java | 14 +--
.../org/apache/submarine/server/s3/Client.java | 81 ++++++++++++----
.../database/mappers/ModelVersionMapper.xml | 14 +--
.../server/model/database/ModelVersionTagTest.java | 2 +-
.../server/model/database/ModelVersionTest.java | 12 +--
.../server/rest/ModelVersionRestApiTest.java | 30 +++---
.../org/apache/submarine/server/s3/ClientTest.java | 37 +++++---
.../server/submitter/k8s/K8SJobSubmitterTest.java | 1 -
31 files changed, 440 insertions(+), 249 deletions(-)
diff --git a/dev-support/database/submarine-model.sql b/dev-support/database/submarine-model.sql
index 3ace957..0726323 100644
--- a/dev-support/database/submarine-model.sql
+++ b/dev-support/database/submarine-model.sql
@@ -33,7 +33,7 @@ DROP TABLE IF EXISTS `model_version`;
CREATE TABLE `model_version` (
`name` VARCHAR(256) NOT NULL COMMENT 'Name of model',
`version` INTEGER NOT NULL,
- `source` VARCHAR(512) NOT NULL COMMENT 'Model saved link',
+ `id` VARCHAR(64) NOT NULL COMMENT 'Id of the model',
`user_id` VARCHAR(64) NOT NULL COMMENT 'Id of the created user',
`experiment_id` VARCHAR(64) NOT NULL,
`model_type` VARCHAR(64) NOT NULL COMMENT 'Type of model',
@@ -43,8 +43,8 @@ CREATE TABLE `model_version` (
`dataset` VARCHAR(256) COMMENT 'Which dataset is used',
`description` VARCHAR(5000),
CONSTRAINT `model_version_pk` PRIMARY KEY (`name`, `version`),
- FOREIGN KEY(`name`) REFERENCES `registered_model` (`name`) ON UPDATE CASCADE ON DELETE CASCADE,
- UNIQUE(`source`)
+ UNIQUE (`name`, `id`),
+ FOREIGN KEY(`name`) REFERENCES `registered_model` (`name`) ON UPDATE CASCADE ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
DROP TABLE IF EXISTS `model_version_tag`;
@@ -62,7 +62,7 @@ CREATE TABLE `metric` (
`key` VARCHAR(190) NOT NULL COMMENT 'Metric key: `String` (limit 190 characters). Part of *Primary Key* for ``metric`` table.',
`value` FLOAT NOT NULL COMMENT 'Metric value: `Float`. Defined as *Non-null* in schema.',
`worker_index` VARCHAR(32) NOT NULL COMMENT 'Metric worker_index: `String` (limit 32 characters). Part of *Primary Key* for\r\n ``metrics`` table.',
- `timestamp` DATETIME(3) NOT NULL COMMENT 'Timestamp recorded for this metric entry: `DATETIME` (millisecond precision).
+ `timestamp` DATETIME(3) NOT NULL COMMENT 'Timestamp recorded for this metric entry: `DATETIME` (millisecond precision).
Part of *Primary Key* for ``metrics`` table.',
`step` INTEGER NOT NULL COMMENT 'Step recorded for this metric entry: `INTEGER`.',
`is_nan` BOOLEAN NOT NULL COMMENT 'True if the value is in fact NaN.',
diff --git a/dev-support/examples/mnist-tensorflow/MirroredStrategy/mnist_keras_distributed.py b/dev-support/examples/mnist-tensorflow/MirroredStrategy/mnist_keras_distributed.py
index eabf9bd..943e9e2 100644
--- a/dev-support/examples/mnist-tensorflow/MirroredStrategy/mnist_keras_distributed.py
+++ b/dev-support/examples/mnist-tensorflow/MirroredStrategy/mnist_keras_distributed.py
@@ -90,6 +90,7 @@ class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print("\nLearning rate for epoch {} is {}".format(epoch + 1, model.optimizer.lr.numpy()))
submarine.log_metric("lr", model.optimizer.lr.numpy())
+ submarine.save_model(model, "tensorflow", "mnist-tf")
# Put all the callbacks together.
@@ -101,7 +102,7 @@ callbacks = [
]
if __name__ == "__main__":
- EPOCHS = 5
+ EPOCHS = 2
hist = model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
for i in range(EPOCHS):
submarine.log_metric("val_loss", hist.history["loss"][i], i)
@@ -111,7 +112,6 @@ if __name__ == "__main__":
print("Eval loss: {}, Eval accuracy: {}".format(eval_loss, eval_acc))
submarine.log_param("loss", eval_loss)
submarine.log_param("acc", eval_acc)
-
"""Reference:
https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy
"""
diff --git a/dev-support/examples/nn-pytorch/model.py b/dev-support/examples/nn-pytorch/model.py
index f236281..c591949 100644
--- a/dev-support/examples/nn-pytorch/model.py
+++ b/dev-support/examples/nn-pytorch/model.py
@@ -32,10 +32,9 @@ class LinearNNModel(torch.nn.Module):
if __name__ == "__main__":
net = LinearNNModel()
submarine.save_model(
- model_type="pytorch",
model=net,
- artifact_path="pytorch-nn-model",
- registered_model_name="simple-nn-model",
+ model_type="pytorch",
+ registered_model_name="simple-pytorch-model",
input_dim=[2],
output_dim=[1],
)
diff --git a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
index 30622c2..50e469c 100644
--- a/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
+++ b/submarine-sdk/pysubmarine/submarine/artifacts/repository.py
@@ -19,30 +19,27 @@ import boto3
class Repository:
- def __init__(self, experiment_id: str):
+ def __init__(self):
self.client = boto3.client(
"s3",
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
endpoint_url=os.environ.get("MLFLOW_S3_ENDPOINT_URL"),
)
- self.dest_path = experiment_id
self.bucket = "submarine"
def _upload_file(self, local_file: str, bucket: str, key: str) -> None:
self.client.upload_file(Filename=local_file, Bucket=bucket, Key=key)
- def _list_artifact_subfolder(self, artifact_path: str):
+ def list_artifact_subfolder(self, dest_path):
response = self.client.list_objects(
Bucket=self.bucket,
- Prefix=os.path.join(self.dest_path, artifact_path) + "/",
+ Prefix=f"{dest_path}/",
Delimiter="/",
)
return response.get("CommonPrefixes")
- def log_artifact(self, local_file: str, artifact_path: str) -> None:
- dest_path = self.dest_path
- dest_path = os.path.join(dest_path, artifact_path)
+ def log_artifact(self, dest_path: str, local_file: str) -> None:
dest_path = os.path.join(dest_path, os.path.basename(local_file))
self._upload_file(
local_file=local_file,
@@ -50,14 +47,7 @@ class Repository:
key=dest_path,
)
- def log_artifacts(self, local_dir: str, artifact_path: str) -> str:
- dest_path = self.dest_path
- list_of_subfolder = self._list_artifact_subfolder(artifact_path)
- if list_of_subfolder is None:
- artifact_path = os.path.join(artifact_path, "1")
- else:
- artifact_path = os.path.join(artifact_path, str(len(list_of_subfolder) + 1))
- dest_path = os.path.join(dest_path, artifact_path)
+ def log_artifacts(self, dest_path: str, local_dir: str) -> str:
local_dir = os.path.abspath(local_dir)
for (root, _, filenames) in os.walk(local_dir):
upload_path = dest_path
@@ -72,8 +62,8 @@ class Repository:
)
return f"s3://{self.bucket}/{dest_path}"
- def delete_folder(self) -> None:
- objects_to_delete = self.client.list_objects(Bucket=self.bucket, Prefix=self.dest_path)
+ def delete_folder(self, dest_path) -> None:
+ objects_to_delete = self.client.list_objects(Bucket=self.bucket, Prefix=dest_path)
if objects_to_delete.get("Contents") is not None:
delete_keys: dict = {"Objects": []}
delete_keys["Objects"] = [
diff --git a/submarine-sdk/pysubmarine/submarine/client/__init__.py b/submarine-sdk/pysubmarine/submarine/client/__init__.py
index bec3f2d..1a210eb 100644
--- a/submarine-sdk/pysubmarine/submarine/client/__init__.py
+++ b/submarine-sdk/pysubmarine/submarine/client/__init__.py
@@ -62,3 +62,6 @@ from submarine.client.models.notebook_meta import NotebookMeta
from submarine.client.models.notebook_pod_spec import NotebookPodSpec
from submarine.client.models.notebook_spec import NotebookSpec
from submarine.client.models.serve_spec import ServeSpec
+
+# import utils
+from submarine.client.utils.api_utils import *
diff --git a/submarine-sdk/pysubmarine/submarine/models/pytorch.py b/submarine-sdk/pysubmarine/submarine/client/utils/__init__.py
similarity index 56%
copy from submarine-sdk/pysubmarine/submarine/models/pytorch.py
copy to submarine-sdk/pysubmarine/submarine/client/utils/__init__.py
index 1b45d30..9d94611 100644
--- a/submarine-sdk/pysubmarine/submarine/models/pytorch.py
+++ b/submarine-sdk/pysubmarine/submarine/client/utils/__init__.py
@@ -1,11 +1,11 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
+# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
+# the License. You may obtain a copy of the License at
#
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -13,12 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
+from submarine.client.utils.api_utils import generate_host, get_api_client
-import torch
-
-
-def save_model(model, artifact_path: str, input_dim: list) -> None:
- example_forward_example = torch.rand(input_dim)
- scripted_model = torch.jit.trace(model, example_forward_example)
- scripted_model.save(model, os.path.join(artifact_path, "model.pth"))
+__all__ = [
+ "generate_host",
+ "get_api_client",
+]
diff --git a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
index 4d2f3fb..bbfab64 100644
--- a/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
+++ b/submarine-sdk/pysubmarine/submarine/entities/model_registry/model_version.py
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from datetime import datetime
+from typing import Optional
+
from submarine.entities._submarine_object import _SubmarineObject
@@ -23,22 +26,22 @@ class ModelVersion(_SubmarineObject):
def __init__(
self,
- name,
- version,
- source,
- user_id,
- experiment_id,
- model_type,
- current_stage,
- creation_time,
- last_updated_time,
- dataset=None,
- description=None,
- tags=None,
+ name: str,
+ version: int,
+ id: str,
+ user_id: str,
+ experiment_id: str,
+ model_type: str,
+ current_stage: str,
+ creation_time: datetime,
+ last_updated_time: datetime,
+ dataset: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[list] = None,
):
self._name = name
self._version = version
- self._source = source
+ self._id = id
self._user_id = user_id
self._experiment_id = experiment_id
self._model_type = model_type
@@ -60,9 +63,9 @@ class ModelVersion(_SubmarineObject):
return self._version
@property
- def source(self):
- """String. Source path for the model."""
- return self._source
+ def id(self):
+ """String. ID of the model"""
+ return self._id
@property
def user_id(self):
diff --git a/submarine-sdk/pysubmarine/submarine/models/pytorch.py b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
index 1b45d30..1fc7943 100644
--- a/submarine-sdk/pysubmarine/submarine/models/pytorch.py
+++ b/submarine-sdk/pysubmarine/submarine/models/pytorch.py
@@ -21,4 +21,4 @@ import torch
def save_model(model, artifact_path: str, input_dim: list) -> None:
example_forward_example = torch.rand(input_dim)
scripted_model = torch.jit.trace(model, example_forward_example)
- scripted_model.save(model, os.path.join(artifact_path, "model.pth"))
+ scripted_model.save(os.path.join(artifact_path, "model.pt"))
diff --git a/submarine-sdk/pysubmarine/submarine/store/database/models.py b/submarine-sdk/pysubmarine/submarine/store/database/models.py
index 8f2f587..68f592d 100644
--- a/submarine-sdk/pysubmarine/submarine/store/database/models.py
+++ b/submarine-sdk/pysubmarine/submarine/store/database/models.py
@@ -27,6 +27,7 @@ from sqlalchemy import (
PrimaryKeyConstraint,
String,
Text,
+ UniqueConstraint,
)
from sqlalchemy.dialects.mysql import DATETIME
from sqlalchemy.ext.declarative import declarative_base
@@ -154,13 +155,13 @@ class SqlRegisteredModelTag(Base):
return RegisteredModelTag(self.tag)
-# +----------+---------+-------------------------------+-----+
-# | name | version | source | ... |
-# +----------+---------+-------------------------------+-----+
-# | ResNet50 | 1 | s3://submarine/ResNet50/1/ | ... |
-# | ResNet50 | 2 | s3://submarine/ResNet50/2/ | ... |
-# | BERT | 1 | s3://submarine/BERT/1/ | ... |
-# +----------+---------+-------------------------------+-----+
+# +----------+---------+----------------------------------+-----+
+# | name | version | id | ... |
+# +----------+---------+----------------------------------+-----+
+# | ResNet50 | 1 | 4ed6572b74a54020b0987ebf53170940 | ... |
+# | ResNet50 | 2 | 1a67f138c1ff41778edf83451d5fd52f | ... |
+# | BERT | 1 | 42ae7f58ba354872a95f6872e16c3544 | ... |
+# +----------+---------+----------------------------------+-----+
class SqlModelVersion(Base):
@@ -180,10 +181,9 @@ class SqlModelVersion(Base):
Version of registered model: Part of *Primary Key* for ``model_version`` table.
"""
- source = Column(String(512), nullable=False, unique=True)
+ id = Column(String(64), nullable=False)
"""
- Source of model: Part of *Primary Key* for ``model_version`` table.
- database link refer to this version of model.
+ ID of the model.
"""
user_id = Column(String(64), nullable=False)
@@ -239,11 +239,14 @@ class SqlModelVersion(Base):
"SqlRegisteredModel", back_populates="model_versions"
)
- __table_args__ = (PrimaryKeyConstraint("name", "version", "source", name="model_version_pk"),)
+ __table_args__ = (
+ PrimaryKeyConstraint("name", "version", name="model_version_pk"),
+ UniqueConstraint("name", "id"),
+ )
def __repr__(self):
return (
- f"<SqlModelVersion ({self.name}, {self.version}, {self.source}, {self.user_id},"
+ f"<SqlModelVersion ({self.name}, {self.version}, {self.user_id},"
f" {self.experiment_id}, {self.current_stage}, {self.creation_time},"
f" {self.last_updated_time}, {self.dataset}, {self.description})>"
)
@@ -256,7 +259,7 @@ class SqlModelVersion(Base):
return ModelVersion(
name=self.name,
version=self.version,
- source=self.source,
+ id=self.id,
user_id=self.user_id,
experiment_id=self.experiment_id,
model_type=self.model_type,
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
index b0344f9..2ed3852 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/abstract_store.py
@@ -126,7 +126,7 @@ class AbstractStore:
def create_model_version(
self,
name: str,
- source: str,
+ id: str,
user_id: str,
experiment_id: str,
model_type: str,
@@ -137,7 +137,7 @@ class AbstractStore:
"""
Create a new version of the registered model
:param name: Registered model name.
- :param source: Source path where this version of model is stored.
+ :param id: Model ID generated when model is created and stored in the description.json
:param user_id: User ID from server that created this model
:param experiment_id: Experiment ID which this model is created.
:param dataset: Dataset which this version of model is used.
diff --git a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
index a16d358..5444d99 100644
--- a/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/submarine/store/model_registry/sqlalchemy_store.py
@@ -340,7 +340,7 @@ class SqlAlchemyStore(AbstractStore):
def create_model_version(
self,
name: str,
- source: str,
+ id: str,
user_id: str,
experiment_id: str,
model_type: str,
@@ -351,7 +351,7 @@ class SqlAlchemyStore(AbstractStore):
"""
Create a new version of the registered model
:param name: Registered model name.
- :param source: Source path where this version of model is stored.
+ :param id: Model ID generated when model is created and stored in the description.json
:param user_id: User ID from server that created this model
:param experiment_id: Experiment ID which this model is created.
:param dataset: Dataset which this version of model is used.
@@ -378,7 +378,7 @@ class SqlAlchemyStore(AbstractStore):
model_version = SqlModelVersion(
name=name,
version=next_version(sql_registered_model),
- source=source,
+ id=id,
user_id=user_id,
experiment_id=experiment_id,
model_type=model_type,
@@ -517,8 +517,8 @@ class SqlAlchemyStore(AbstractStore):
:return: A single URI location.
"""
with self.ManagedSessionMaker() as session:
- sql_model = self._get_sql_model_version(session, name, version)
- return sql_model.to_submarine_entity().source
+ mv = self._get_sql_model_version(session, name, version)
+ return f"s3://submarine/registry/{mv.id}/{mv.name}/{mv.version}"
@classmethod
def _get_sql_model_version_tag(
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/client.py b/submarine-sdk/pysubmarine/submarine/tracking/client.py
index 76f350a..3199632 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/client.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/client.py
@@ -53,11 +53,12 @@ class SubmarineClient(object):
os.environ["MLFLOW_S3_ENDPOINT_URL"] = s3_registry_uri or S3_ENDPOINT_URL
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id or AWS_ACCESS_KEY_ID
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key or AWS_SECRET_ACCESS_KEY
- self.artifact_repo = Repository(utils.get_job_id())
+ self.artifact_repo = Repository()
self.db_uri = db_uri or submarine.get_db_uri()
self.store = utils.get_tracking_sqlalchemy_store(self.db_uri)
self.model_registry = utils.get_model_registry_sqlalchemy_store(self.db_uri)
self.serve_client = ServeClient(host)
+ self.experiment_id = utils.get_job_id()
def log_metric(
self,
@@ -99,33 +100,85 @@ class SubmarineClient(object):
def save_model(
self,
- model_type: str,
model,
- artifact_path: str,
+ model_type: str,
registered_model_name: str = None,
input_dim: list = None,
output_dim: list = None,
) -> None:
"""
- Save a model into the minio pod.
- :param model_type: The type of the model.
+ Save a model into the minio pod or even register a model.
:param model: Model.
- :param artifact_path: Relative path of the artifact in the minio pod.
+ :param model_type: The type of the model.
:param registered_model_name: If not None, register model into the model registry with
this name. If None, the model only be saved in minio pod.
:param input_dim: Save the input dimension of the given model to the description file.
:param output_dim: Save the output dimension of the given model to the description file.
"""
pattern = r"[0-9A-Za-z][0-9A-Za-z-_]*[0-9A-Za-z]|[0-9A-Za-z]"
- if not re.fullmatch(pattern, artifact_path):
+ if registered_model_name and not re.fullmatch(pattern, registered_model_name):
raise Exception(
- "Artifact_path must only contains numbers, characters, hyphen and underscore. "
- "Artifact_path must starts and ends with numbers or characters."
+ "Registered_model_name must only contains numbers, characters, hyphen and"
+ " underscore. Registered_model_name must starts and ends with numbers or"
+ " characters."
+ )
+
+ model_id = utils.generate_model_id()
+
+ dest_path = self._generate_experiment_artifact_path(f"experiment/{self.experiment_id}")
+
+ # log artifact under the experiment directory
+ self._log_artifact(model, dest_path, model_type, model_id, input_dim, output_dim)
+
+ # Register model
+ if registered_model_name is not None:
+ try:
+ self.model_registry.get_registered_model(registered_model_name)
+ except SubmarineException:
+ self.model_registry.create_registered_model(name=registered_model_name)
+
+ mv = self.model_registry.create_model_version(
+ name=registered_model_name,
+ id=model_id,
+ user_id="", # TODO(jeff-901): the user id is needed to be specified.
+ experiment_id=self.experiment_id,
+ model_type=model_type,
+ )
+
+ # log artifact under the registry directory
+ self._log_artifact(
+ model,
+ f"registry/{mv.name}-{mv.version}-{model_id}/{mv.name}/{mv.version}",
+ model_type,
+ model_id,
+ input_dim,
+ output_dim,
)
+
+ def _log_artifact(
+ self,
+ model,
+ dest_path: str,
+ model_type: str,
+ model_id: str,
+ input_dim: list = None,
+ output_dim: list = None,
+ registered: bool = False,
+ ):
+ """
+ Save a model into the minio pod.
+ :param model: Model.
+ :param dest_path: Destination path of the submarine bucket in the minio pod.
+ :param model_type: The type of the model.
+ :param model_id: ID of the model.
+ :param input_dim: Save the input dimension of the given model to the description file.
+ :param output_dim: Save the output dimension of the given model to the description file.
+ """
with tempfile.TemporaryDirectory() as tempdir:
description: Dict[str, Any] = dict()
- model_save_dir = os.path.join(tempdir, "1")
- os.mkdir(model_save_dir)
+ model_save_dir = tempdir
+ if not os.path.exists(model_save_dir):
+ os.mkdir(model_save_dir)
if model_type == "pytorch":
import submarine.models.pytorch
@@ -143,6 +196,7 @@ class SubmarineClient(object):
raise Exception("No valid type of model has been matched to {}".format(model_type))
# Write description file
+ description["id"] = model_id
if input_dim is not None:
description["input"] = [
{
@@ -156,25 +210,22 @@ class SubmarineClient(object):
}
]
description["model_type"] = model_type
- with open(os.path.join(tempdir, "description.json"), "w") as f:
+ with open(os.path.join(model_save_dir, "description.json"), "w") as f:
json.dump(description, f)
# Log all files into minio
- source = self.artifact_repo.log_artifacts(tempdir, artifact_path)
+ self.artifact_repo.log_artifacts(dest_path, model_save_dir)
- # Register model
- if registered_model_name is not None:
- try:
- self.model_registry.get_registered_model(registered_model_name)
- except SubmarineException:
- self.model_registry.create_registered_model(name=registered_model_name)
- self.model_registry.create_model_version(
- name=registered_model_name,
- source=source,
- user_id="", # TODO(jeff-901): the user id is needed to be specified.
- experiment_id=utils.get_job_id(),
- model_type=model_type,
- )
+ def _generate_experiment_artifact_path(self, dest_path: str) -> str:
+ """
+ :param dest_path: destination of current experiment directory
+ """
+ list_of_subfolder = self.artifact_repo.list_artifact_subfolder(dest_path)
+ return (
+ os.path.join(dest_path, str(len(list_of_subfolder) + 1))
+ if list_of_subfolder
+ else os.path.join(dest_path, "1")
+ )
def create_serve(self, model_name: str, model_version: int, async_req: bool = True):
"""
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
index ec9ed9e..a2c8b65 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/fluent.py
@@ -55,9 +55,8 @@ def log_metric(key, value, step=None):
def save_model(
- model_type: str,
model,
- artifact_path: str,
+ model_type: str,
registered_model_name: str = None,
input_dim: list = None,
output_dim: list = None,
@@ -66,14 +65,10 @@ def save_model(
Save a model into the minio pod.
:param model_type: The type of the model.
:param model: Model.
- :param artifact_path: Relative path of the artifact in the minio pod.
:param registered_model_name: If none None, register model into the model registry with
this name. If None, the model only be saved in minio pod.
"""
- SubmarineClient().save_model(
- model_type, model, artifact_path, registered_model_name, input_dim, output_dim
- )
- SubmarineClient().save_model(model_type, model, artifact_path, registered_model_name)
+ SubmarineClient().save_model(model, model_type, registered_model_name, input_dim, output_dim)
def create_serve(model_name: str, model_version: int):
diff --git a/submarine-sdk/pysubmarine/submarine/tracking/utils.py b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
index 4a223e1..7255835 100644
--- a/submarine-sdk/pysubmarine/submarine/tracking/utils.py
+++ b/submarine-sdk/pysubmarine/submarine/tracking/utils.py
@@ -97,3 +97,7 @@ def get_model_registry_sqlalchemy_store(store_uri: str):
from submarine.store.model_registry.sqlalchemy_store import SqlAlchemyStore
return SqlAlchemyStore(store_uri)
+
+
+def generate_model_id() -> str:
+ return uuid.uuid4().hex
diff --git a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
index 1379c08..c730164 100644
--- a/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
+++ b/submarine-sdk/pysubmarine/tests/entities/model_registry/test_model_version.py
@@ -23,7 +23,7 @@ class TestModelVersion:
default_data = {
"name": "test",
"version": 1,
- "source": "path/to/source",
+ "id": "1f94b4fadbe144ea8ced0ce195855cfc",
"user_id": "admin",
"experiment_id": "experiment_1",
"model_type": "tensorflow",
@@ -40,7 +40,7 @@ class TestModelVersion:
model_metadata,
name,
version,
- source,
+ id,
user_id,
experiment_id,
model_type,
@@ -54,7 +54,7 @@ class TestModelVersion:
isinstance(model_metadata, ModelVersion)
assert model_metadata.name == name
assert model_metadata.version == version
- assert model_metadata.source == source
+ assert model_metadata.id == id
assert model_metadata.user_id == user_id
assert model_metadata.experiment_id == experiment_id
assert model_metadata.model_type == model_type
@@ -69,7 +69,7 @@ class TestModelVersion:
mv = ModelVersion(
self.default_data["name"],
self.default_data["version"],
- self.default_data["source"],
+ self.default_data["id"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["model_type"],
@@ -84,7 +84,7 @@ class TestModelVersion:
mv,
self.default_data["name"],
self.default_data["version"],
- self.default_data["source"],
+ self.default_data["id"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["model_type"],
@@ -103,7 +103,7 @@ class TestModelVersion:
mv = ModelVersion(
self.default_data["name"],
self.default_data["version"],
- self.default_data["source"],
+ self.default_data["id"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["model_type"],
@@ -118,7 +118,7 @@ class TestModelVersion:
mv,
self.default_data["name"],
self.default_data["version"],
- self.default_data["source"],
+ self.default_data["id"],
self.default_data["user_id"],
self.default_data["experiment_id"],
self.default_data["model_type"],
diff --git a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
index afe3217..2a80bb9 100644
--- a/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
+++ b/submarine-sdk/pysubmarine/tests/store/model_registry/test_sqlalchemy_store.py
@@ -108,10 +108,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
new_name = "test_rename_RN_new"
rm = self.store.create_registered_model(name)
self.store.create_model_version(
- name, "path/to/source1", "test", "application_1234", "tensorflow"
+ name, "model_id_0", "test", "application_1234", "tensorflow"
)
self.store.create_model_version(
- name, "path/to/source2", "test", "application_1235", "tensorflow"
+ name, "model_id_1", "test", "application_1235", "tensorflow"
)
mv1d = self.store.get_model_version(name, 1)
mv2d = self.store.get_model_version(name, 2)
@@ -146,10 +146,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
rm2 = self.store.create_registered_model(name2, tags=rm_tags)
mv_tags = ["mv_tag1", "mv_tag2"]
rm1mv1 = self.store.create_model_version(
- rm1.name, "path/to/source1", "test", "application_1234", "tensorflow", tags=mv_tags
+ rm1.name, "model_id_0", "test", "application_1234", "tensorflow", tags=mv_tags
)
rm2mv1 = self.store.create_model_version(
- rm2.name, "path/to/source2", "test", "application_1234", "tensorflow", tags=mv_tags
+ rm2.name, "model_id_1", "test", "application_1234", "tensorflow", tags=mv_tags
)
# check store
@@ -384,7 +384,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(model_name)
fake_datetime = datetime.now()
mv1 = self.store.create_model_version(
- model_name, "path/to/source1", "test", "application_1234", "tensorflow"
+ model_name, "model_id_0", "test", "application_1234", "tensorflow"
)
self.assertEqual(mv1.name, model_name)
self.assertEqual(mv1.version, 1)
@@ -392,18 +392,18 @@ class TestSqlAlchemyStore(unittest.TestCase):
m1d = self.store.get_model_version(mv1.name, mv1.version)
self.assertEqual(m1d.name, model_name)
+ self.assertEqual(m1d.id, "model_id_0")
self.assertEqual(m1d.user_id, "test")
self.assertEqual(m1d.experiment_id, "application_1234")
self.assertEqual(m1d.model_type, "tensorflow")
self.assertEqual(m1d.current_stage, STAGE_NONE)
self.assertEqual(m1d.creation_time, fake_datetime)
self.assertEqual(m1d.last_updated_time, fake_datetime)
- self.assertEqual(m1d.source, "path/to/source1")
self.assertEqual(m1d.dataset, None)
# new model for same registered model autoincrement version
m2 = self.store.create_model_version(
- model_name, "path/to/source2", "test", "application_1234", "tensorflow"
+ model_name, "model_id_1", "test", "application_1234", "tensorflow"
)
m2d = self.store.get_model_version(m2.name, m2.version)
self.assertEqual(m2.version, 2)
@@ -412,7 +412,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
# create model with tags
tags = ["tag1", "tag2"]
m3 = self.store.create_model_version(
- model_name, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
+ model_name, "model_id_2", "test", "application_1234", "tensorflow", tags=tags
)
m3d = self.store.get_model_version(m3.name, m3.version)
self.assertEqual(m3.version, 3)
@@ -424,7 +424,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
description = "A test description."
m4 = self.store.create_model_version(
model_name,
- "path/to/source4",
+ "model_id_3",
"test",
"application_1234",
"tensorflow",
@@ -440,11 +440,12 @@ class TestSqlAlchemyStore(unittest.TestCase):
name = "test_update_MV_description"
self.store.create_registered_model(name)
mv1 = self.store.create_model_version(
- name, "path/to/source", "test", "application_1234", "tensorflow"
+ name, "model_id_0", "test", "application_1234", "tensorflow"
)
m1d = self.store.get_model_version(mv1.name, mv1.version)
self.assertEqual(m1d.name, name)
self.assertEqual(m1d.version, 1)
+ self.assertEqual(m1d.id, "model_id_0")
self.assertEqual(m1d.description, None)
# update description
@@ -454,6 +455,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
m1d = self.store.get_model_version(mv1.name, mv1.version)
self.assertEqual(m1d.name, name)
self.assertEqual(m1d.version, 1)
+ self.assertEqual(m1d.id, "model_id_0")
self.assertEqual(m1d.description, "New description.")
self.assertEqual(m1d.last_updated_time, fake_datetime)
@@ -461,10 +463,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
name = "test_transition_MV_stage"
self.store.create_registered_model(name)
mv1 = self.store.create_model_version(
- name, "path/to/source1", "test", "application_1234", "tensorflow"
+ name, "model_id_0", "test", "application_1234", "tensorflow"
)
- m2 = self.store.create_model_version(
- name, "path/to/source2", "test", "application_1234", "tensorflow"
+ mv2 = self.store.create_model_version(
+ name, "model_id_1", "test", "application_1234", "tensorflow"
)
fake_datetime = datetime.strptime("2021-11-11 11:11:11.111000", "%Y-%m-%d %H:%M:%S.%f")
@@ -516,7 +518,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.transition_model_version_stage(mv1.name, mv1.version, "stage")
# No change for other model
- m2d = self.store.get_model_version(m2.name, m2.version)
+ m2d = self.store.get_model_version(mv2.name, mv2.version)
self.assertEqual(m2d.current_stage, STAGE_NONE)
def test_delete_model_version(self):
@@ -524,7 +526,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
tags = ["tag1", "tag2"]
self.store.create_registered_model(name)
mv = self.store.create_model_version(
- name, "path/to/source", "test", "application_1234", "tensorflow", tags=tags
+ name, "model_id_0", "test", "application_1234", "tensorflow", tags=tags
)
mvd = self.store.get_model_version(mv.name, mv.version)
self.assertEqual(mvd.name, name)
@@ -560,24 +562,19 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name)
fake_datetime = datetime.now()
mv = self.store.create_model_version(
- name,
- source="path/to/source",
- user_id="test",
- experiment_id="application_1234",
- model_type="tensorflow",
- tags=tags,
+ name, "model_id_0", "test", "application_1234", "tensorflow", tags=tags
)
self.assertEqual(mv.creation_time, fake_datetime)
self.assertEqual(mv.last_updated_time, fake_datetime)
mvd = self.store.get_model_version(mv.name, mv.version)
self.assertEqual(mvd.name, name)
+ self.assertEqual(mvd.id, "model_id_0")
self.assertEqual(mvd.user_id, "test")
self.assertEqual(mvd.experiment_id, "application_1234")
self.assertEqual(mvd.model_type, "tensorflow")
self.assertEqual(mvd.current_stage, STAGE_NONE)
self.assertEqual(mvd.creation_time, fake_datetime)
self.assertEqual(mvd.last_updated_time, fake_datetime)
- self.assertEqual(mvd.source, "path/to/source")
self.assertEqual(mvd.dataset, None)
self.assertEqual(mvd.description, None)
self.assertEqual(mvd.tags, tags)
@@ -597,24 +594,24 @@ class TestSqlAlchemyStore(unittest.TestCase):
tags = ["tag1", "tag2", "tag3"]
models = [
self.store.create_model_version(
- name1, "path/to/source1", "test", "application_1234", "tensorflow"
+ name1, "model_id_0", "test", "application_1234", "tensorflow"
),
self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=[tags[0]]
+ name1, "model_id_1", "test", "application_1234", "tensorflow", tags=[tags[0]]
),
self.store.create_model_version(
- name1, "path/to/source3", "test", "application_1234", "tensorflow", tags=[tags[1]]
+ name1, "model_id_2", "test", "application_1234", "tensorflow", tags=[tags[1]]
),
self.store.create_model_version(
name1,
- "path/to/source4",
+ "model_id_3",
"test",
"application_1234",
"tensorflow",
tags=[tags[0], tags[2]],
),
self.store.create_model_version(
- name1, "path/to/source5", "test", "application_1234", "tensorflow", tags=tags
+ name1, "model_id_4", "test", "application_1234", "tensorflow", tags=tags
),
]
@@ -649,16 +646,10 @@ class TestSqlAlchemyStore(unittest.TestCase):
name = "test_get_model_version_uri"
self.store.create_registered_model(name)
mv = self.store.create_model_version(
- name, "path/to/source", "test", "application_1234", "tensorflow"
+ name, "model_id_0", "test", "application_1234", "tensorflow"
)
uri = self.store.get_model_version_uri(mv.name, mv.version)
- self.assertEqual(uri, "path/to/source")
-
- # uri does not change even if model version is updated
- self.store.transition_model_version_stage(mv.name, mv.version, STAGE_PRODUCTION)
- self.store.update_model_version_description(mv.name, mv.version, "New description.")
- uri = self.store.get_model_version_uri(mv.name, mv.version)
- self.assertEqual(uri, "path/to/source")
+ self.assertEqual(uri, f"s3://submarine/registry/{mv.id}/{mv.name}/{mv.version}")
# cannot retrieve URI for deleted model version
self.store.delete_model_version(mv.name, mv.version)
@@ -672,13 +663,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name1)
self.store.create_registered_model(name2)
rm1mv1 = self.store.create_model_version(
- name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
+ name1, "model_id_0", "test", "application_1234", "tensorflow", tags=tags
)
- rm1m2 = self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
+ rm1mv2 = self.store.create_model_version(
+ name1, "model_id_1", "test", "application_1234", "tensorflow", tags=tags
)
rm2mv1 = self.store.create_model_version(
- name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
+ name2, "model_id_2", "test", "application_1234", "tensorflow", tags=tags
)
new_tag = "new tag"
self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
@@ -694,7 +685,7 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.assertEqual(mvd.tags, all_tags)
# does not affect other models
- rm1m2d = self.store.get_model_version(rm1m2.name, rm1m2.version)
+ rm1m2d = self.store.get_model_version(rm1mv2.name, rm1mv2.version)
self.assertEqual(rm1m2d.name, name1)
self.assertEqual(rm1m2d.tags, tags)
rm2mv1 = self.store.get_model_version(rm2mv1.name, rm2mv1.version)
@@ -719,13 +710,13 @@ class TestSqlAlchemyStore(unittest.TestCase):
self.store.create_registered_model(name1)
self.store.create_registered_model(name2)
rm1mv1 = self.store.create_model_version(
- name1, "path/to/source1", "test", "application_1234", "tensorflow", tags=tags
+ name1, "model_id_0", "test", "application_1234", "tensorflow", tags=tags
)
rm1m2 = self.store.create_model_version(
- name1, "path/to/source2", "test", "application_1234", "tensorflow", tags=tags
+ name1, "model_id_1", "test", "application_1234", "tensorflow", tags=tags
)
rm2mv1 = self.store.create_model_version(
- name2, "path/to/source3", "test", "application_1234", "tensorflow", tags=tags
+ name2, "model_id_2", "test", "application_1234", "tensorflow", tags=tags
)
new_tag = "new tag"
self.store.add_model_version_tag(rm1mv1.name, rm1mv1.version, new_tag)
diff --git a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
index 59feefa..d072b5a 100644
--- a/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
+++ b/submarine-sdk/pysubmarine/tests/tracking/test_tracking.py
@@ -29,6 +29,7 @@ from submarine.tracking.client import SubmarineClient
from .tf_model import LinearNNModel
JOB_ID = "application_123456789"
+REGISTERED_MODEL_NAME = "registerd_model_name"
MLFLOW_S3_ENDPOINT_URL = "http://localhost:9000"
@@ -69,7 +70,8 @@ class TestTracking(unittest.TestCase):
environ["MLFLOW_S3_ENDPOINT_URL"] = MLFLOW_S3_ENDPOINT_URL
environ["AWS_ACCESS_KEY_ID"] = "submarine_minio"
environ["AWS_SECRET_ACCESS_KEY"] = "submarine_minio"
- Repository(JOB_ID).delete_folder()
+ Repository().delete_folder(f"experiment/{JOB_ID}")
+ Repository().delete_folder(f"registry/{REGISTERED_MODEL_NAME}")
def test_log_param(self):
submarine.log_param("name_1", "a")
@@ -97,15 +99,12 @@ class TestTracking(unittest.TestCase):
input_arr = tensorflow.random.uniform((1, 5))
model = LinearNNModel()
model(input_arr)
- registered_model_name = "registerd_model_name"
- self.client.save_model("tensorflow", model, "name_1", registered_model_name)
- self.client.save_model("tensorflow", model, "name_2", registered_model_name)
+ self.client.save_model(model, "tensorflow", "name_1", REGISTERED_MODEL_NAME)
+ self.client.save_model(model, "tensorflow", "name_2", REGISTERED_MODEL_NAME)
# Validate model_versions
- model_versions = self.model_registry.list_model_versions(registered_model_name)
+ model_versions = self.model_registry.list_model_versions(REGISTERED_MODEL_NAME)
assert len(model_versions) == 2
- assert model_versions[0].name == registered_model_name
+ assert model_versions[0].name == REGISTERED_MODEL_NAME
assert model_versions[0].version == 1
- assert model_versions[0].source == f"s3://submarine/{JOB_ID}/name_1/1"
- assert model_versions[1].name == registered_model_name
+ assert model_versions[1].name == REGISTERED_MODEL_NAME
assert model_versions[1].version == 2
- assert model_versions[1].source == f"s3://submarine/{JOB_ID}/name_2/1"
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/model/ServeSpec.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/model/ServeSpec.java
index 0a0f8fe..5fa11ad 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/model/ServeSpec.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/model/ServeSpec.java
@@ -21,6 +21,7 @@ package org.apache.submarine.server.api.model;
public class ServeSpec {
private String modelName;
private Integer modelVersion;
+ private String modelId;
private String modelType;
private String modelURI;
@@ -40,6 +41,14 @@ public class ServeSpec {
this.modelVersion = modelVersion;
}
+ public String getModelId() {
+ return modelId;
+ }
+
+ public void setModelId(String modelId) {
+ this.modelId = modelId;
+ }
+
public String getModelType() {
return modelType;
}
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/ModelManager.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/ModelManager.java
index e9e1878..c0d01d4 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/ModelManager.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/ModelManager.java
@@ -23,7 +23,6 @@ import org.json.JSONArray;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.nio.file.Paths;
import javax.ws.rs.core.Response;
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
@@ -35,6 +34,7 @@ import org.apache.submarine.server.api.proto.TritonModelConfig;
import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
import org.apache.submarine.server.model.database.service.ModelVersionService;
import org.apache.submarine.server.s3.Client;
+import org.apache.submarine.server.s3.S3Constants;
public class ModelManager {
@@ -100,7 +100,7 @@ public class ModelManager {
} else {
if (spec.getModelName() == null) {
throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
- "Invalid. Model name in Serve Soec is null.");
+ "Invalid. Model name in Serve Spec is null.");
}
Integer modelVersion = spec.getModelVersion();
if (modelVersion == null || modelVersion <= 0) {
@@ -115,14 +115,29 @@ public class ModelManager {
// Get model type and model uri from DB and set the value in the spec.
ModelVersionEntity modelVersion = modelVersionService.select(spec.getModelName(), spec.getModelVersion());
- spec.setModelURI(modelVersion.getSource());
- spec.setModelType(modelVersion.getModelType());
+ String modelType = modelVersion.getModelType();
+ String modelId = modelVersion.getId();
+ spec.setModelType(modelType);
+ spec.setModelId(modelId);
+
+ String modelUniquePath = String.format("%s-%d-%s", spec.getModelName(), spec.getModelVersion(), modelId);
+ if (spec.getModelType().equals("pytorch")) {
+ spec.setModelURI(String.format("s3://%s/registry/%s", S3Constants.BUCKET, modelUniquePath));
+ } else if (spec.getModelType().equals("tensorflow")) {
+ spec.setModelURI(String.format("s3://%s/registry/%s/%s", S3Constants.BUCKET, modelUniquePath,
+ spec.getModelName()));
+ } else {
+ throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
+ String.format("Unexpected model type: %s", modelType));
+ }
}
private void transferDescription(ServeSpec spec) {
Client s3Client = new Client();
- String res = new String(s3Client.downloadArtifact(
- Paths.get(spec.getModelName(), "description.json").toString()));
+ String modelUniquePath = String.format("%s-%d-%s",
+ spec.getModelName(), spec.getModelVersion(), spec.getModelId());
+ String res = new String(s3Client.downloadArtifact(String.format("registry/%s/%s/%d/description.json",
+ modelUniquePath, spec.getModelName(), spec.getModelVersion())));
JSONObject description = new JSONObject(res);
TritonModelConfig.ModelConfig.Builder modelConfig = TritonModelConfig.ModelConfig.newBuilder();
@@ -148,14 +163,20 @@ public class ModelManager {
modelConfig.addOutput(modelOutput);
}
- s3Client.logArtifact(Paths.get(spec.getModelName(), "config.pbtxt").toString(),
+ s3Client.logArtifact(String.format("registry/%s/%s/config.pbtxt", modelUniquePath, spec.getModelName()),
modelConfig.toString().getBytes());
}
private ServeResponse getServeResponse(ServeSpec spec){
ServeResponse serveResponse = new ServeResponse();
- serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/api/v1.0/predictions",
- spec.getModelName(), spec.getModelVersion()));
+ if (spec.getModelType().equals("pytorch")) {
+ serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/v2/models/%s/infer",
+ spec.getModelName(), spec.getModelVersion(), spec.getModelName()));
+ } else {
+ serveResponse.setUrl(String.format("http://{submarine ip}/%s/%d/api/v1.0/predictions",
+ spec.getModelName(), spec.getModelVersion()));
+ }
+
return serveResponse;
}
}
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
index a66b6a3..6e30209 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/model/database/entities/ModelVersionEntity.java
@@ -27,7 +27,7 @@ public class ModelVersionEntity {
private Integer version;
- private String source;
+ private String id;
private String userId;
@@ -63,12 +63,12 @@ public class ModelVersionEntity {
this.version = version;
}
- public String getSource() {
- return source;
+ public String getId() {
+ return id;
}
- public void setSource(String source) {
- this.source = source;
+ public void setId(String id) {
+ this.id = id;
}
public String getUserId() {
@@ -148,8 +148,8 @@ public class ModelVersionEntity {
public String toString() {
return "ModelVersionEntity{" +
"name='" + name + '\'' +
- ",version='" + version + '\'' +
- ", source='" + source + '\'' +
+ ", version='" + version + '\'' +
+ ", id='" + id + '\'' +
", userId='" + userId + '\'' +
", experimentId='" + experimentId + '\'' +
", modelType='" + modelType + '\'' +
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ExperimentRestApi.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ExperimentRestApi.java
index c0a3fba..d14f8f6 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ExperimentRestApi.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ExperimentRestApi.java
@@ -274,7 +274,7 @@ public class ExperimentRestApi {
@ApiResponse(responseCode = "404", description = "Experiment not found")})
public Response getArtifactPaths(@PathParam(RestConstants.ID) String id) {
try {
- List<String> artifactPaths = minioClient.listArtifactByExperimentId(id);
+ List<String> artifactPaths = minioClient.listArtifact(String.format("experiment/%s", id));
return new JsonResponse.Builder<List<String>>(Response.Status.OK).success(true)
.result(artifactPaths).build();
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
index 56c405d..da52118 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ModelVersionRestApi.java
@@ -19,6 +19,8 @@
package org.apache.submarine.server.rest;
+import org.json.JSONObject;
+
import java.util.List;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
@@ -37,7 +39,6 @@ import io.swagger.v3.oas.annotations.media.Content;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
-
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.model.database.entities.ModelVersionEntity;
import org.apache.submarine.server.model.database.entities.ModelVersionTagEntity;
@@ -46,6 +47,7 @@ import org.apache.submarine.server.model.database.service.ModelVersionService;
import org.apache.submarine.server.model.database.service.ModelVersionTagService;
import org.apache.submarine.server.response.JsonResponse;
+import org.apache.submarine.server.s3.Client;
/**
* Model version REST API v1.
@@ -60,6 +62,8 @@ public class ModelVersionRestApi {
/* Model version tag service */
private final ModelVersionTagService modelVersionTagService = new ModelVersionTagService();
+ private final Client s3Client = new Client();
+
/**
* Return the Pong message for test the connectivity.
*
@@ -77,6 +81,61 @@ public class ModelVersionRestApi {
}
/**
+ * Create a model version from the artifacts stored under {@code baseDir}.
+ *
+ * @param entity model version entity; its id, model type and version are filled in here
+ * example: {
+ * "name": "example_name",
+ * "experimentId": "4d4d02f06f6f437fa29e1ee8a9276d87",
+ * "userId": "",
+ * "description": "example_description",
+ * "tags": ["123", "456"]
+ * }
+ * @param baseDir artifact base directory containing description.json (no trailing slash)
+ * example: "experiment/experiment-1643015349312-0001/1"
+ * @return success message
+ */
+ @POST
+ @Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
+ @Operation(summary = "Create a model version instance", tags = { "model-version" }, responses = {
+ @ApiResponse(description = "successful operation",
+ content = @Content(schema = @Schema(implementation = JsonResponse.class)))})
+ public Response createModelVersion(ModelVersionEntity entity,
+ @QueryParam("baseDir") String baseDir) {
+ try {
+ String res = new String(s3Client.downloadArtifact(
+ String.format("%s/description.json", baseDir)));
+ JSONObject description = new JSONObject(res);
+ String modelType = description.get("model_type").toString();
+ String id = description.get("id").toString();
+ entity.setId(id);
+ entity.setModelType(modelType);
+ // the new version is one past the highest existing version (1 for the first version)
+ int version = modelVersionService.selectAllVersions(entity.getName()).stream()
+ .mapToInt(ModelVersionEntity::getVersion)
+ .max().orElse(0) + 1;
+
+ entity.setVersion(version);
+ modelVersionService.insert(entity);
+
+ // the directory of storing a single model must be unique for serving
+ String uniqueModelPath = String.format("%s-%d-%s", entity.getName(), version, id);
+ // copy artifacts; the trailing "/" keeps sibling dirs such as ".../10" out of ".../1"
+ String baseDirPrefix = String.format("%s/", baseDir);
+ s3Client.listAllObjects(baseDirPrefix).forEach(s -> {
+ String relativePath = s.substring(baseDirPrefix.length());
+ s3Client.copyArtifact(String.format("registry/%s/%s/%d/%s", uniqueModelPath,
+ entity.getName(), version, relativePath), s);
+ });
+
+ return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
+ .message("Create a model version instance").build();
+ } catch (SubmarineRuntimeException e) {
+ return parseModelVersionServiceException(e);
+ }
+ }
+
+ /**
* List all model versions under same registered model name.
*
* @param name registered model name
@@ -126,7 +185,7 @@ public class ModelVersionRestApi {
*
* @param name model version's name
* @param version model version's version
- * @return seccess message
+ * @return success message
*/
@DELETE
@Path("/{name}/{version}")
@@ -277,7 +336,7 @@ public class ModelVersionRestApi {
throw new SubmarineRuntimeException(Response.Status.OK.getStatusCode(),
"Invalid. Model version's version is null.");
}
- Integer versionNum;
+ int versionNum;
try {
versionNum = Integer.parseInt(version);
if (versionNum < 1){
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
index 16bc29d..6592d77 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/RestConstants.java
@@ -93,7 +93,7 @@ public class RestConstants {
* Serve.
*/
public static final String SERVE = "serve";
-
+
/**
* Internal
*/
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ServeRestApi.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ServeRestApi.java
index 65f9d72..4864a0d 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ServeRestApi.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/rest/ServeRestApi.java
@@ -64,16 +64,18 @@ public class ServeRestApi {
@POST
@Consumes({ RestConstants.MEDIA_TYPE_YAML, MediaType.APPLICATION_JSON })
- @Operation(summary = "Create a serve instance", tags = { "serve" }, responses = {
- @ApiResponse(description = "successful operation",
- content = @Content(schema = @Schema(implementation = JsonResponse.class)))})
+ @Operation(summary = "Create a serve instance",
+ tags = {"serve"},
+ responses = {
+ @ApiResponse(description = "successful operation", content = @Content(
+ schema = @Schema(implementation = JsonResponse.class)))})
public Response createServe(ServeSpec spec) {
try {
ServeResponse serveResponse = modelManager.createServe(spec);
return new JsonResponse.Builder<ServeResponse>(Response.Status.OK).success(true)
.message("Create a serve instance").result(serveResponse).build();
} catch (SubmarineRuntimeException e) {
- return parseModelVersionServiceException(e);
+ return parseServeServiceException(e);
}
}
@@ -88,11 +90,11 @@ public class ServeRestApi {
return new JsonResponse.Builder<String>(Response.Status.OK).success(true)
.message("Delete the model serve instance").build();
} catch (SubmarineRuntimeException e) {
- return parseModelVersionServiceException(e);
+ return parseServeServiceException(e);
}
}
- private Response parseModelVersionServiceException(SubmarineRuntimeException e) {
+ private Response parseServeServiceException(SubmarineRuntimeException e) {
return new JsonResponse.Builder<String>(e.getCode()).message(e.getMessage()).build();
}
}
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/s3/Client.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/s3/Client.java
index 7db06c1..32823fe 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/s3/Client.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/s3/Client.java
@@ -24,8 +24,11 @@ import java.io.InputStream;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
+import java.util.Stack;
import javax.ws.rs.core.Response;
+import io.minio.CopyObjectArgs;
+import io.minio.CopySource;
import io.minio.GetObjectArgs;
import io.minio.ListObjectsArgs;
import io.minio.MinioClient;
@@ -58,27 +61,22 @@ public class Client {
/**
* Get a list of artifact path under the experiment.
*
- * @param experimentId experiment id
+ * @param path path of the artifact directory
* @return a list of artifact path
*/
- public List<String> listArtifactByExperimentId(String experimentId) throws SubmarineRuntimeException {
- Iterable<Result<Item>> artifactNames = minioClient.listObjects(ListObjectsArgs.builder()
- .bucket(S3Constants.BUCKET).prefix(experimentId + "/").delimiter("/").build());
- List<String> response = new ArrayList<>();
- Iterable<Result<Item>> artifacts;
- for (Result<Item> artifactName : artifactNames) {
- try {
- artifacts = minioClient.listObjects(ListObjectsArgs.builder().bucket(S3Constants.BUCKET)
- .prefix(artifactName.get().objectName()).delimiter("/").build());
- for (Result<Item> artifact: artifacts) {
- response.add("s3://" + S3Constants.BUCKET + "/" + artifact.get().objectName());
- }
- } catch (Exception e) {
- throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
- e.getMessage());
+ public List<String> listArtifact(String path) throws SubmarineRuntimeException {
+ try {
+ Iterable<Result<Item>> artifacts = minioClient.listObjects(ListObjectsArgs.builder()
+ .bucket(S3Constants.BUCKET).prefix(path + "/").delimiter("/").build());
+ List<String> response = new ArrayList<>();
+ for (Result<Item> artifact: artifacts) {
+ response.add("s3://" + S3Constants.BUCKET + "/" + artifact.get().objectName());
}
+ return response;
+ } catch (Exception e) {
+ throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
+ e.getMessage());
}
- return response;
}
/**
@@ -129,6 +127,29 @@ public class Client {
return buffer;
}
+
+ /**
+ * Copy an artifact.
+ *
+ * @param targetPath path of the target file
+ * @param sourcePath path of the source file
+ */
+ public void copyArtifact(String targetPath, String sourcePath) {
+ try {
+ minioClient.copyObject(CopyObjectArgs.builder()
+ .bucket(S3Constants.BUCKET)
+ .object(targetPath)
+ .source(CopySource.builder()
+ .bucket(S3Constants.BUCKET)
+ .object(sourcePath)
+ .build())
+ .build());
+ } catch (Exception e) {
+ throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
+ e.getMessage());
+ }
+ }
+
/**
* Upload an artifact.
*
@@ -147,6 +168,32 @@ public class Client {
}
}
+ /** List every object key starting with {@code path}, recursing into "/"-delimited sub-prefixes. */
+ public List<String> listAllObjects(String path) throws SubmarineRuntimeException {
+ List<String> result = new ArrayList<>();
+ Stack<String> dirs = new Stack<>();
+ dirs.add(path);
+ while (!dirs.empty()) {
+ String dir = dirs.pop();
+ try {
+ Iterable<Result<Item>> artifacts = minioClient.listObjects(ListObjectsArgs.builder()
+ .bucket(S3Constants.BUCKET).prefix(dir).delimiter("/").build());
+ for (Result<Item> artifact: artifacts) {
+ String objectName = artifact.get().objectName();
+ if (!objectName.endsWith("/")) {
+ result.add(objectName);
+ } else if (!objectName.equals(dir)) { // re-queueing a marker equal to the prefix loops forever
+ dirs.add(objectName);
+ }
+ }
+ } catch (Exception e) {
+ throw new SubmarineRuntimeException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(),
+ e.getMessage());
+ }
+ }
+ return result;
+ }
+
/**
* Delete all elements under the given folder path.
*/
diff --git a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
index 2f3b344..ad87879 100644
--- a/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
+++ b/submarine-server/server-core/src/main/resources/org/apache/submarine/database/mappers/ModelVersionMapper.xml
@@ -22,7 +22,7 @@
<resultMap id="resultMap" type="org.apache.submarine.server.model.database.entities.ModelVersionEntity">
<result column="name" property="name" />
<result column="version" property="version" />
- <result column="source" property="source" />
+ <result column="id" property="id" />
<result column="user_id" property="userId" />
<result column="experiment_id" property="experimentId" />
<result column="model_type" property="modelType" />
@@ -36,7 +36,7 @@
<resultMap id="resultMapWithTag" type="org.apache.submarine.server.model.database.entities.ModelVersionEntity">
<result column="name" property="name" />
<result column="version" property="version" />
- <result column="source" property="source" />
+ <result column="id" property="id" />
<result column="user_id" property="userId" />
<result column="experiment_id" property="experimentId" />
<result column="model_type" property="modelType" />
@@ -51,7 +51,7 @@
</resultMap>
<sql id="Base_Column_List">
- name, version, source, user_id, experiment_id, model_type, current_stage, creation_time,
+ name, version, id, user_id, experiment_id, model_type, current_stage, creation_time,
last_updated_time, dataset, description
</sql>
@@ -77,10 +77,10 @@
</select>
<insert id="insert" parameterType="org.apache.submarine.server.model.database.entities.ModelVersionEntity">
- insert into model_version (name, version, source, user_id, experiment_id, model_type, current_stage, creation_time, last_updated_time, dataset, description)
- values (#{name,jdbcType=VARCHAR}, #{version,jdbcType=INTEGER}, #{source,jdbcType=VARCHAR},
- #{userId,jdbcType=VARCHAR}, #{experimentId,jdbcType=VARCHAR}, #{modelType,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
- NOW(3), NOW(3), #{dataset,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR});
+ insert into model_version (name, version, id, user_id, experiment_id, model_type, current_stage, creation_time, last_updated_time, dataset, description)
+ values (#{name,jdbcType=VARCHAR}, #{version,jdbcType=INTEGER}, #{id,jdbcType=VARCHAR}, #{userId,jdbcType=VARCHAR},
+ #{experimentId,jdbcType=VARCHAR}, #{modelType,jdbcType=VARCHAR}, #{currentStage,jdbcType=VARCHAR},
+ NOW(3), NOW(3), #{dataset,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR});
<if test="tags != null and !tags.isEmpty()">
insert INTO model_version_tag (name, version, tag) values
<foreach collection="tags" item="tag" index="index" separator=",">
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
index 72e5253..95a32d9 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTagTest.java
@@ -54,7 +54,7 @@ public class ModelVersionTagTest {
ModelVersionEntity modelVersionEntity = new ModelVersionEntity();
modelVersionEntity.setName(name);
modelVersionEntity.setVersion(version);
- modelVersionEntity.setSource("path/to/source");
+ modelVersionEntity.setId("model_version_id");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
modelVersionEntity.setModelType("tensorflow");
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
index 3c089b2..72bf72b 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/model/database/ModelVersionTest.java
@@ -52,7 +52,7 @@ public class ModelVersionTest {
ModelVersionEntity modelVersionEntity = new ModelVersionEntity();
modelVersionEntity.setName(name);
modelVersionEntity.setVersion(version);
- modelVersionEntity.setSource("path/to/source");
+ modelVersionEntity.setId("model_version_id");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
modelVersionEntity.setModelType("tensorflow");
@@ -65,7 +65,7 @@ public class ModelVersionTest {
ModelVersionEntity modelVersionEntity2 = new ModelVersionEntity();
modelVersionEntity2.setName(name);
modelVersionEntity2.setVersion(version2);
- modelVersionEntity2.setSource("path/to/source2");
+ modelVersionEntity2.setId("model_version_id2");
modelVersionEntity2.setUserId("test");
modelVersionEntity2.setExperimentId("application_1234");
modelVersionEntity2.setModelType("tensorflow");
@@ -92,7 +92,7 @@ public class ModelVersionTest {
ModelVersionEntity modelVersionEntity = new ModelVersionEntity();
modelVersionEntity.setName(name);
modelVersionEntity.setVersion(version);
- modelVersionEntity.setSource("path/to/source");
+ modelVersionEntity.setId("model_version_id");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
modelVersionEntity.setModelType("tensorflow");
@@ -118,7 +118,7 @@ public class ModelVersionTest {
ModelVersionEntity modelVersionEntity = new ModelVersionEntity();
modelVersionEntity.setName(name);
modelVersionEntity.setVersion(version);
- modelVersionEntity.setSource("path/to/source");
+ modelVersionEntity.setId("model_version_id");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
modelVersionEntity.setModelType("tensorflow");
@@ -150,7 +150,7 @@ public class ModelVersionTest {
ModelVersionEntity modelVersionEntity = new ModelVersionEntity();
modelVersionEntity.setName(name);
modelVersionEntity.setVersion(version);
- modelVersionEntity.setSource("path/to/source");
+ modelVersionEntity.setId("model_version_id");
modelVersionEntity.setUserId("test");
modelVersionEntity.setExperimentId("application_1234");
modelVersionEntity.setModelType("tensorflow");
@@ -163,7 +163,7 @@ public class ModelVersionTest {
private void compareModelVersion(ModelVersionEntity expected, ModelVersionEntity actual) {
Assert.assertEquals(expected.getName(), actual.getName());
Assert.assertEquals(expected.getVersion(), actual.getVersion());
- Assert.assertEquals(expected.getSource(), actual.getSource());
+ Assert.assertEquals(expected.getId(), actual.getId());
Assert.assertEquals(expected.getUserId(), actual.getUserId());
Assert.assertEquals(expected.getExperimentId(), actual.getExperimentId());
Assert.assertEquals(expected.getModelType(), actual.getModelType());
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
index ca2e770..d5b01ee 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/rest/ModelVersionRestApiTest.java
@@ -47,8 +47,9 @@ public class ModelVersionRestApiTest {
private final String registeredModelDescription = "test registered model description";
private final String modelVersionDescription = "test model version description";
private final String newModelVersionDescription = "new test registered model description";
- private final String modelVersionSource = "s3://submarine/test";
- private final String modelVersionUid = "test123";
+ private final String modelVersionId = "model_version_id";
+ private final String modelVersionId2 = "model_version_id2";
+ private final String modelVersionUserId = "test123";
private final String modelVersionExperimentId = "experiment_123";
private final String modelVersionModelType = "experiment_123";
private final String modelVersionTag = "testTag";
@@ -74,23 +75,29 @@ public class ModelVersionRestApiTest {
registeredModel.setDescription(registeredModelDescription);
registeredModelService.insert(registeredModel);
modelVersion1.setName(registeredModelName);
- modelVersion1.setDescription(modelVersionDescription + "1");
+ modelVersion1.setDescription(String.format("%s1", modelVersionDescription));
modelVersion1.setVersion(1);
- modelVersion1.setSource(modelVersionSource + "1");
- modelVersion1.setUserId(modelVersionUid);
+ modelVersion1.setId(modelVersionId);
+ modelVersion1.setUserId(modelVersionUserId);
modelVersion1.setExperimentId(modelVersionExperimentId);
modelVersion1.setModelType(modelVersionModelType);
modelVersionService.insert(modelVersion1);
+
modelVersion2.setName(registeredModelName);
- modelVersion2.setDescription(modelVersionDescription + "2");
+ modelVersion2.setDescription(String.format("%s2", modelVersionDescription));
modelVersion2.setVersion(2);
- modelVersion2.setSource(modelVersionSource + "2");
- modelVersion2.setUserId(modelVersionUid);
+ modelVersion2.setId(modelVersionId2);
+ modelVersion2.setUserId(modelVersionUserId);
modelVersion2.setExperimentId(modelVersionExperimentId);
modelVersion2.setModelType(modelVersionModelType);
modelVersionService.insert(modelVersion2);
}
+ @After
+ public void tearDown(){
+ registeredModelService.deleteAll(); // drop the registered model inserted during setup
+ }
+
@Test
public void testListModelVersion(){
Response listModelVersionResponse = modelVersionRestApi.listModelVersions(registeredModelName);
@@ -147,10 +154,7 @@ public class ModelVersionRestApiTest {
verifyResult(modelVersion2, result.get(0));
}
- @After
- public void tearDown(){
- registeredModelService.deleteAll();
- }
+
private <T> T getResultFromResponse(Response response, Class<T> typeT) {
String entity = (String) response.getEntity();
@@ -175,7 +179,7 @@ public class ModelVersionRestApiTest {
assertEquals(result.getName(), actual.getName());
assertEquals(result.getDescription(), actual.getDescription());
assertEquals(result.getVersion(), actual.getVersion());
- assertEquals(result.getSource(), actual.getSource());
+ assertEquals(result.getId(), actual.getId());
assertEquals(result.getExperimentId(), actual.getExperimentId());
assertEquals(result.getModelType(), actual.getModelType());
}
diff --git a/submarine-server/server-core/src/test/java/org/apache/submarine/server/s3/ClientTest.java b/submarine-server/server-core/src/test/java/org/apache/submarine/server/s3/ClientTest.java
index ff6ed97..7f86e3a 100644
--- a/submarine-server/server-core/src/test/java/org/apache/submarine/server/s3/ClientTest.java
+++ b/submarine-server/server-core/src/test/java/org/apache/submarine/server/s3/ClientTest.java
@@ -22,13 +22,13 @@ package org.apache.submarine.server.s3;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
+
import java.util.List;
public class ClientTest {
- private Client client = new Client("http://localhost:9000");
+ private final Client client = new Client("http://localhost:9000");
private final String testExperimentId = "experiment-sample";
- private final String bucket = "s3://submarine";
@After
public void cleanAll() {
@@ -38,30 +38,45 @@ public class ClientTest {
@Test
public void testLogArtifactAndDownloadArtifact() {
String path = "sample_folder/sample_file";
- byte[] content = "0123456789".getBytes();;
+ byte[] content = "0123456789".getBytes(); // ASCII-only, so the platform charset is harmless
client.logArtifact(path, content);
byte[] response = client.downloadArtifact(path);
Assert.assertArrayEquals(content, response);
}
@Test
- public void testListArtifactByExperimentIdAndDeleteArtifactByExperiment() {
- String testModelName = "sample";
+ public void testListAndDeleteArtifactByExperimentId() { // artifacts live under experiment/<id>
byte[] content = "0123456789".getBytes();
String[] artifactPaths = {
- testExperimentId + "/" + testModelName + "/1",
- testExperimentId + "/" + testModelName + "/2"};
+ String.format("experiment/%s/1", testExperimentId),
+ String.format("experiment/%s/2", testExperimentId)
+ };
String[] actualResults = {
- bucket + "/" + testExperimentId + "/" + testModelName + "/1",
- bucket + "/" + testExperimentId + "/" + testModelName + "/2"};
+ String.format("s3://%s/experiment/%s/1", S3Constants.BUCKET, testExperimentId),
+ String.format("s3://%s/experiment/%s/2", S3Constants.BUCKET, testExperimentId)
+ };
client.logArtifact(artifactPaths[0], content);
client.logArtifact(artifactPaths[1], content);
- List<String> results = client.listArtifactByExperimentId(testExperimentId);
+ List<String> results = client.listArtifact(String.format("experiment/%s", testExperimentId));
Assert.assertArrayEquals(actualResults, results.toArray());
client.deleteArtifactsByExperiment(testExperimentId);
- results = client.listArtifactByExperimentId(testExperimentId);
+ results = client.listArtifact(String.format("experiment/%s", testExperimentId));
Assert.assertArrayEquals(new String[0], results.toArray());
}
+
+ @Test
+ public void testCopyObject() { // a copied artifact must have the same bytes as the source
+ String path = "sample_folder/sample_file";
+ byte[] content = "0123456789".getBytes();
+ client.logArtifact(path, content);
+ byte[] response = client.downloadArtifact(path);
+ Assert.assertArrayEquals(content, response);
+
+ String copyPath = "sample_folder_copy/sample_file"; // NOTE(review): confirm cleanAll() removes this
+ client.copyArtifact(copyPath, path); // destination first, then source (per the asserts below)
+ response = client.downloadArtifact(copyPath);
+ Assert.assertArrayEquals(content, response);
+ }
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/K8SJobSubmitterTest.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/K8SJobSubmitterTest.java
index 8e25cc6..eaa8b06 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/K8SJobSubmitterTest.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/K8SJobSubmitterTest.java
@@ -68,7 +68,6 @@ public class K8SJobSubmitterTest extends SpecBuilder {
spec.setModelName("simple");
spec.setModelVersion(1);
spec.setModelType("tensorflow");
- spec.setModelURI("s3://submarine/simple");
submitter.createServe(spec);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org