You are viewing a plain text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to dev@submarine.apache.org by wc...@apache.org on 2021/05/10 13:12:39 UTC

[submarine] branch master updated: SUBMARINE-817. Add model management e2e test

This is an automated email from the ASF dual-hosted git repository.

wcjtw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git


The following commit(s) were added to refs/heads/master by this push:
     new 17ffc4e  SUBMARINE-817. Add model management e2e test
17ffc4e is described below

commit 17ffc4ef8220df33f5a11e83fdf45f6fde212296
Author: ByronHsu <by...@gmail.com>
AuthorDate: Mon May 10 21:12:32 2021 +0800

    SUBMARINE-817. Add model management e2e test
    
    * add e2e test for models
    
    * fix linting error
    
    * fix format
    
    * avoid used port
    
    * revert validation.py
---
 .github/workflows/python.yml                       |  2 +
 .../pysubmarine/submarine/models/client.py         |  7 +--
 .../pysubmarine/tests/models/test_model_e2e.py     | 54 ++++++++++++++++++++++
 3 files changed, 60 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 283754b..283a5c6 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -82,6 +82,8 @@ jobs:
           kubectl get pods
           kubectl port-forward svc/submarine-database 3306:3306 &
           kubectl port-forward svc/submarine-server 8080:8080 &
+          kubectl port-forward svc/submarine-minio-service 9000:9000 &
+          kubectl port-forward svc/submarine-mlflow-service 5001:5000 &
       - name: Setup python environment
         uses: actions/setup-python@v1
         with:
diff --git a/submarine-sdk/pysubmarine/submarine/models/client.py b/submarine-sdk/pysubmarine/submarine/models/client.py
index 7214384..7bf6495 100644
--- a/submarine-sdk/pysubmarine/submarine/models/client.py
+++ b/submarine-sdk/pysubmarine/submarine/models/client.py
@@ -25,14 +25,15 @@ from .constant import (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
 
 class ModelsClient():
 
-    def __init__(self):
+    def __init__(self, tracking_uri=None, registry_uri=None):
         """
         Set up mlflow server connection, including: s3 endpoint, aws, tracking server
         """
-        os.environ["MLFLOW_S3_ENDPOINT_URL"] = MLFLOW_S3_ENDPOINT_URL
+        os.environ[
+            "MLFLOW_S3_ENDPOINT_URL"] = registry_uri or MLFLOW_S3_ENDPOINT_URL
         os.environ["AWS_ACCESS_KEY_ID"] = AWS_ACCESS_KEY_ID
         os.environ["AWS_SECRET_ACCESS_KEY"] = AWS_SECRET_ACCESS_KEY
-        os.environ["MLFLOW_TRACKING_URI"] = MLFLOW_TRACKING_URI
+        os.environ["MLFLOW_TRACKING_URI"] = tracking_uri or MLFLOW_TRACKING_URI
         self._client = MlflowClient()
 
     def log_model(self, name, checkpoint):
diff --git a/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
new file mode 100644
index 0000000..63e1d6d
--- /dev/null
+++ b/submarine-sdk/pysubmarine/tests/models/test_model_e2e.py
@@ -0,0 +1,54 @@
+"""
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements.  See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership.  The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License.  You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied.  See the License for the
+ specific language governing permissions and limitations
+ under the License.
+"""
+
+import os
+
+import numpy as np
+import pytest
+
+from pytorch import LinearNNModel
+from submarine import ModelsClient
+from submarine.models import constant
+
+
+@pytest.fixture(name="models_client", scope="class")
+def models_client_fixture():
+    client = ModelsClient("http://localhost:5001", "http://localhost:9000")
+    return client
+
+
+@pytest.mark.e2e
+class TestSubmarineModelsClientE2E():
+
+    def test_model(self, models_client):
+        model = LinearNNModel()
+        # log
+        name = "simple-nn-model"
+        models_client.log_model(name, model)
+        # update
+        new_name = "new-simple-nn-model"
+        models_client.update_model(name, new_name)
+        # load
+        name = new_name
+        version = "1"
+        model = models_client.load_model(name, version)
+        x = np.float32([[1.0], [2.0]])
+        y = model.predict(x)
+        assert y.shape[0] == 2
+        assert y.shape[1] == 1
+        # delete
+        models_client.delete_model(name, '1')

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org