You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@submarine.apache.org by pi...@apache.org on 2022/06/29 14:45:33 UTC
[submarine] branch master updated: SUBMARINE-1255. Add XGBoost support to MLJob
This is an automated email from the ASF dual-hosted git repository.
pingsutw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/submarine.git
The following commit(s) were added to refs/heads/master by this push:
new e7ed4a5b SUBMARINE-1255. Add XGBoost support to MLJob
e7ed4a5b is described below
commit e7ed4a5b182a7cc3cad32f406a8327e6c788d966
Author: Kevin Su <pi...@apache.org>
AuthorDate: Wed Jun 29 01:10:50 2022 +0800
SUBMARINE-1255. Add XGBoost support to MLJob
### What is this PR for?
Add XGBoost support to MLJob
### What type of PR is it?
Feature
### Todos
* [x] update ExperimentSpecParser
* [x] create XGBoost Folder in Model Folder
* [ ] add XGBoost button in Workbench
### What is the Jira issue?
https://issues.apache.org/jira/browse/SUBMARINE-1255
### How should this be tested?
### Screenshots (if appropriate)
### Questions:
* Do the license files need updating? No
* Are there breaking changes for older versions? No
* Does this need new documentation? No
Author: Kevin Su <pi...@apache.org>
Author: JackLiu00521 <ja...@gmail.com>
Signed-off-by: Kevin <pi...@apache.org>
Closes #939 from JackLiu00521/SUBMARINE-1255 and squashes the following commits:
a2f0b2aa [Kevin Su] Fixed tests
c270fd33 [Kevin Su] Fixed tests
02b75b88 [Kevin Su] Fixed tests
f3ba3e18 [Kevin Su] Fixed tests
4eab29b8 [Kevin Su] Fixed tests
edc1da70 [Kevin Su] Fixed tests
e30c6e05 [Kevin Su] Fixed tests
832c5533 [Kevin Su] Fixed tests
40484fd9 [Kevin Su] Fixed tests
b29b559b [Kevin Su] Fixed tests
b8ad1d8e [Kevin Su] Fixed tests
778b6ba8 [Kevin Su] Merge branch 'master' of github.com:apache/submarine into SUBMARINE-1255
f47dfdf8 [Kevin Su] Fixed tests
c5ebd478 [JackLiu00521] Merge pull request #3 from pingsutw/939
30b69615 [Kevin Su] Few updates
304724cb [Kevin Su] Few updates
785ce9bf [JackLiu00521] fix errors
41d24f0f [JackLiu00521] 20220627 temp commit
301b2101 [JackLiu00521] temp commit
5feb1686 [JackLiu00521] temp commit before build
85435bb3 [JackLiu00521] SUBMARINE-1255. update ExperimentSpecParserTest
8625ed00 [JackLiu00521] update indentation
d24aad83 [JackLiu00521] update parseXGBoostJobSpec
bca3c60d [JackLiu00521] add xgboost model, update parser
830ab55d [JackLiu00521] commit after fetch upstream
---
.github/workflows/master.yml | 2 +-
.github/workflows/python.yml | 4 +-
helm-charts/submarine/templates/rbac.yaml | 1 +
pom.xml | 2 +-
.../submarine/submarine-observer-rbac.yaml | 4 +-
.../artifacts/submarine/submarine-rbac.yaml | 4 +
.../artifacts/submarine-observer-rbac.yaml | 2 +
.../artifacts/submarine-server-rbac.yaml | 4 +
.../server/api/common/CustomResourceType.java | 8 +-
.../submarine/server/api/spec/ExperimentMeta.java | 3 +-
.../server/api/spec/ExperimentTaskSpec.java | 2 +-
.../server/internal/InternalServiceManager.java | 32 +++---
.../server/submitter/k8s/K8sSubmitter.java | 123 ++++++++++++++-------
.../server/submitter/k8s/client/K8sClient.java | 4 +
.../submitter/k8s/client/K8sDefaultClient.java | 16 ++-
.../submitter/k8s/model/xgboostjob/XGBoostJob.java | 61 ++++++++++
.../k8s/model/xgboostjob/XGBoostJobList.java | 60 ++++++++++
.../model/xgboostjob/XGBoostJobReplicaType.java | 57 ++++++++++
.../k8s/model/xgboostjob/XGBoostJobSpec.java | 58 ++++++++++
.../submitter/k8s/parser/ExperimentSpecParser.java | 50 ++++++++-
.../submitter/k8s/ExperimentSpecParserTest.java | 52 ++++++++-
.../server/submitter/k8s/SpecBuilder.java | 1 +
.../server/submitter/k8s/client/K8sMockClient.java | 7 ++
.../src/test/resources/xgboost_job_req.json | 26 +++++
24 files changed, 511 insertions(+), 72 deletions(-)
diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index edbb12bc..dcaafecc 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -400,7 +400,7 @@ jobs:
TEST_MODULES: "-pl :submarine-server-database"
run: |
echo ">>> mvn $TEST_FLAG $TEST_MODULES -B"
- mvn $TEST_FLAG $TEST_MODULES -B
+ mvn $TEST_FLAG $TEST_MODULES -B
- name: Build submarine-server-database
env:
MODULES: "-pl :submarine-server-core"
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index ca88e49f..ac263a6d 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -29,7 +29,7 @@ jobs:
- uses: actions/checkout@v2
- name: Install dependencies
run: |
- pip install --upgrade pip
+ pip install --upgrade pip
pip install -r ./dev-support/style-check/python/lint-requirements.txt
pip install -r ./dev-support/style-check/python/mypy-requirements.txt
- name: Check python sdk code style
@@ -51,7 +51,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install requirements
run: |
- pip install --upgrade pip
+ pip install --upgrade pip
pip install -r ./submarine-sdk/pysubmarine/github-actions/test-requirements.txt
- name: Install pysubmarine with tf1 and pytorch
if: matrix.tf-version == '1.15.0'
diff --git a/helm-charts/submarine/templates/rbac.yaml b/helm-charts/submarine/templates/rbac.yaml
index 9bba6547..c6a07dbd 100644
--- a/helm-charts/submarine/templates/rbac.yaml
+++ b/helm-charts/submarine/templates/rbac.yaml
@@ -52,6 +52,7 @@ rules:
- notebooks
- pytorchjobs
- tfjobs
+ - xgboostjobs
verbs:
- "*"
- apiGroups:
diff --git a/pom.xml b/pom.xml
index dbd75ed1..f21c1d5b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -502,7 +502,7 @@
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
- <version>0.8.0</version>
+ <version>0.8.7</version>
<configuration>
<destFile>${basedir}/target/jacoco.exec</destFile>
<dataFile>${basedir}/target/jacoco.exec</dataFile>
diff --git a/submarine-cloud-v2/artifacts/submarine/submarine-observer-rbac.yaml b/submarine-cloud-v2/artifacts/submarine/submarine-observer-rbac.yaml
index c3cd43c5..aa158747 100644
--- a/submarine-cloud-v2/artifacts/submarine/submarine-observer-rbac.yaml
+++ b/submarine-cloud-v2/artifacts/submarine/submarine-observer-rbac.yaml
@@ -11,6 +11,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
@@ -41,4 +43,4 @@ subjects:
roleRef:
kind: Role
name: "submarine-observer"
- apiGroup: rbac.authorization.k8s.io
\ No newline at end of file
+ apiGroup: rbac.authorization.k8s.io
diff --git a/submarine-cloud-v2/artifacts/submarine/submarine-rbac.yaml b/submarine-cloud-v2/artifacts/submarine/submarine-rbac.yaml
index 0429e9c8..c97d4d71 100644
--- a/submarine-cloud-v2/artifacts/submarine/submarine-rbac.yaml
+++ b/submarine-cloud-v2/artifacts/submarine/submarine-rbac.yaml
@@ -29,6 +29,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
@@ -112,6 +114,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
diff --git a/submarine-cloud-v3/artifacts/submarine-observer-rbac.yaml b/submarine-cloud-v3/artifacts/submarine-observer-rbac.yaml
index 60771481..5779964a 100644
--- a/submarine-cloud-v3/artifacts/submarine-observer-rbac.yaml
+++ b/submarine-cloud-v3/artifacts/submarine-observer-rbac.yaml
@@ -28,6 +28,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
diff --git a/submarine-cloud-v3/artifacts/submarine-server-rbac.yaml b/submarine-cloud-v3/artifacts/submarine-server-rbac.yaml
index bf59bd3c..9e9b67e5 100644
--- a/submarine-cloud-v3/artifacts/submarine-server-rbac.yaml
+++ b/submarine-cloud-v3/artifacts/submarine-server-rbac.yaml
@@ -28,6 +28,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
@@ -110,6 +112,8 @@ rules:
- tfjobs/status
- pytorchjobs
- pytorchjobs/status
+ - xgboostjobs
+ - xgboostjobs/status
- notebooks
- notebooks/status
verbs:
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
index e2873b97..db331ab1 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/common/CustomResourceType.java
@@ -20,13 +20,13 @@
package org.apache.submarine.server.api.common;
public enum CustomResourceType {
- TFJob("TFJob"), PyTorchJob("PyTorchJob"), Notebook("Notebook");
+ TFJob("TFJob"), PyTorchJob("PyTorchJob"), XGBoost("XGBoost"), Notebook("Notebook");
private String customResourceType;
-
+
CustomResourceType(String customResourceType) {
- this.customResourceType = customResourceType;
+ this.customResourceType = customResourceType;
}
-
+
public String getCustomResourceType() {
return this.customResourceType;
}
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentMeta.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentMeta.java
index c8dfd86b..6e8d3dd2 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentMeta.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentMeta.java
@@ -150,7 +150,8 @@ public class ExperimentMeta {
*/
public enum SupportedMLFramework {
TENSORFLOW("tensorflow"),
- PYTORCH("pytorch");
+ PYTORCH("pytorch"),
+ XGBOOST("xgboost");
private final String name;
diff --git a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentTaskSpec.java b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentTaskSpec.java
index 7980bf8d..8de561e9 100644
--- a/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentTaskSpec.java
+++ b/submarine-server/server-api/src/main/java/org/apache/submarine/server/api/spec/ExperimentTaskSpec.java
@@ -37,7 +37,7 @@ public class ExperimentTaskSpec {
private String cmd;
private Map<String, String> envVars;
- // should ignored in JSON Serialization
+ // should be ignored in JSON Serialization
private Map<String, String> resourceMap;
public ExperimentTaskSpec() {
diff --git a/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java b/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
index ea61d5d9..25624878 100644
--- a/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
+++ b/submarine-server/server-core/src/main/java/org/apache/submarine/server/internal/InternalServiceManager.java
@@ -38,38 +38,40 @@ public class InternalServiceManager {
private static volatile InternalServiceManager internalServiceManager;
private static final Logger LOG = LoggerFactory.getLogger(InternalServiceManager.class);
private final ExperimentService experimentService;
- private final NotebookService notebookService;
-
+ private final NotebookService notebookService;
+
public static InternalServiceManager getInstance() {
if (internalServiceManager == null) {
internalServiceManager = new InternalServiceManager(new ExperimentService(), new NotebookService());
}
return internalServiceManager;
}
-
+
@VisibleForTesting
protected InternalServiceManager(ExperimentService experimentService, NotebookService notebookService) {
this.experimentService = experimentService;
this.notebookService = notebookService;
}
-
+
public boolean updateCRStatus(CustomResourceType crType, String resourceId,
Map<String, Object> updateObject) {
if (crType.equals(CustomResourceType.Notebook)) {
return updateNotebookStatus(resourceId, updateObject);
- } else if (crType.equals(CustomResourceType.TFJob) || crType.equals(CustomResourceType.PyTorchJob)) {
+ } else if (crType.equals(CustomResourceType.TFJob)
+ || crType.equals(CustomResourceType.PyTorchJob)
+ || crType.equals(CustomResourceType.XGBoost)) {
return updateExperimentStatus(resourceId, updateObject);
}
return false;
}
-
+
private boolean updateExperimentStatus(String resourceId, Map<String, Object> updateObject) {
ExperimentEntity experimentEntity = experimentService.select(resourceId);
if (experimentEntity == null) {
throw new SubmarineRuntimeException(Status.NOT_FOUND.getStatusCode(),
String.format("cannot find experiment with id:%s", resourceId));
}
-
+
if (updateObject.get("status") != null) {
experimentEntity.setExperimentStatus(updateObject.get("status").toString());
}
@@ -79,35 +81,35 @@ public class InternalServiceManager {
}
if (updateObject.get("createdTime") != null) {
experimentEntity.setCreateTime(
- DateTime.parse(updateObject.get("createdTime").toString()).toDate());
+ DateTime.parse(updateObject.get("createdTime").toString()).toDate());
}
if (updateObject.get("runningTime") != null) {
experimentEntity.setRunningTime(
- DateTime.parse(updateObject.get("runningTime").toString()).toDate());
+ DateTime.parse(updateObject.get("runningTime").toString()).toDate());
}
if (updateObject.get("finishedTime") != null) {
experimentEntity.setFinishedTime(
- DateTime.parse(updateObject.get("finishedTime").toString()).toDate());
+ DateTime.parse(updateObject.get("finishedTime").toString()).toDate());
}
-
+
return experimentService.update(experimentEntity);
}
-
+
private boolean updateNotebookStatus(String resourceId, Map<String, Object> updateObject) {
Notebook notebook = notebookService.select(resourceId);
if (notebook == null) {
throw new SubmarineRuntimeException(Status.NOT_FOUND.getStatusCode(),
String.format("cannot find notebook with id:%s", resourceId));
}
-
+
if (updateObject.containsKey("status")) {
notebook.setStatus(updateObject.get("status").toString());
}
-
+
if (updateObject.get("createTime") != null) {
notebook.setCreatedTime(updateObject.get("createTime").toString());
}
-
+
if (updateObject.get("deletedTime") != null) {
notebook.setDeletedTime(updateObject.get("deletedTime").toString());
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
index daaf012f..d696dce5 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/K8sSubmitter.java
@@ -73,6 +73,7 @@ import org.apache.submarine.server.submitter.k8s.model.mljob.MLJob;
import org.apache.submarine.server.submitter.k8s.model.notebook.NotebookCR;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
import org.apache.submarine.server.submitter.k8s.parser.ExperimentSpecParser;
import org.apache.submarine.server.submitter.k8s.util.MLJobConverter;
import org.apache.submarine.server.submitter.k8s.util.NotebookUtils;
@@ -90,6 +91,7 @@ public class K8sSubmitter implements Submitter {
private static final String TF_JOB_SELECTOR_KEY = "tf-job-name=";
private static final String PYTORCH_JOB_SELECTOR_KEY = "pytorch-job-name=";
+ private static final String XGBoost_JOB_SELECTOR_KEY = "xgboost-job-name=";
// Add an exception Consumer, handle the problem that delete operation does not have the resource
public static final Function<ApiException, Object> API_EXCEPTION_404_CONSUMER = e -> {
@@ -193,26 +195,39 @@ public class K8sSubmitter implements Submitter {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
mlJob.getMetadata().setOwnerReferences(OwnerReferenceUtils.getOwnerReference());
- AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(),
- mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? CustomResourceType.TFJob : CustomResourceType.PyTorchJob,
- spec.getMeta().getExperimentId());
+ CustomResourceType customResourceType;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ customResourceType = CustomResourceType.TFJob;
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ customResourceType = CustomResourceType.XGBoost;
+ } else {
+ customResourceType = CustomResourceType.PyTorchJob;
+ }
+
+ AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(), customResourceType,
+ spec.getMeta().getExperimentId());
- Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? k8sClient.getTfJobClient().create(getServerNamespace(), (TFJob) mlJob,
- new CreateOptions()).throwsApiException().getObject()
- : k8sClient.getPyTorchJobClient().create(getServerNamespace(), (PyTorchJob) mlJob,
- new CreateOptions()).throwsApiException().getObject();
+ Object object;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ object = k8sClient.getTfJobClient().create(getServerNamespace(), (TFJob) mlJob,
+ new CreateOptions()).throwsApiException().getObject();
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ object = k8sClient.getXGBoostJobClient().create(getServerNamespace(), (XGBoostJob) mlJob,
+ new CreateOptions()).throwsApiException().getObject();
+ } else {
+ object = k8sClient.getPyTorchJobClient().create(getServerNamespace(), (PyTorchJob) mlJob,
+ new CreateOptions()).throwsApiException().getObject();
+ }
- V1Pod agentPodResult = k8sClient.getPodClient().create(agentPod).throwsApiException().getObject();
+ k8sClient.getPodClient().create(agentPod).throwsApiException().getObject();
experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
LOG.error("K8s submitter: parse Job object failed by " + e.getMessage(), e);
throw new SubmarineRuntimeException(400, e.getMessage());
} catch (ApiException e) {
- LOG.error("K8s submitter: parse Job object failed by " + e.getMessage(), e);
- throw new SubmarineRuntimeException(e.getCode(), "K8s submitter: parse Job object failed by " +
+ LOG.error("K8s submitter: failed to create pod " + e.getMessage(), e);
+ throw new SubmarineRuntimeException(e.getCode(), "K8s submitter: failed to create pod " +
e.getMessage());
}
return experiment;
@@ -225,11 +240,18 @@ public class K8sSubmitter implements Submitter {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
- Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? k8sClient.getTfJobClient().get(getServerNamespace(), mlJob.getMetadata().getName())
- .throwsApiException().getObject()
- : k8sClient.getPyTorchJobClient().get(getServerNamespace(), mlJob.getMetadata().getName())
- .throwsApiException().getObject();
+
+ Object object;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ object = k8sClient.getTfJobClient().get(getServerNamespace(),
+ mlJob.getMetadata().getName()).throwsApiException().getObject();
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ object = k8sClient.getXGBoostJobClient().get(getServerNamespace(),
+ mlJob.getMetadata().getName()).throwsApiException().getObject();
+ } else {
+ object = k8sClient.getPyTorchJobClient().get(getServerNamespace(),
+ mlJob.getMetadata().getName()).throwsApiException().getObject();
+ }
experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);
@@ -253,16 +275,24 @@ public class K8sSubmitter implements Submitter {
PatchOptions patchOptions = new PatchOptions();
patchOptions.setFieldManager(spec.getMeta().getExperimentId());
patchOptions.setForce(true);
- Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? k8sClient.getTfJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
- V1Patch.PATCH_FORMAT_APPLY_YAML,
- new V1Patch(new Gson().toJson(mlJob)),
- patchOptions).throwsApiException().getObject()
- : k8sClient.getPyTorchJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
- V1Patch.PATCH_FORMAT_APPLY_YAML,
- new V1Patch(new Gson().toJson(mlJob)),
- patchOptions).throwsApiException().getObject()
- ;
+ Object object;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ object = k8sClient.getTfJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
+ V1Patch.PATCH_FORMAT_APPLY_YAML,
+ new V1Patch(new Gson().toJson(mlJob)),
+ patchOptions).throwsApiException().getObject();
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ object = k8sClient.getXGBoostJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
+ V1Patch.PATCH_FORMAT_APPLY_YAML,
+ new V1Patch(new Gson().toJson(mlJob)),
+ patchOptions).throwsApiException().getObject();
+ } else {
+ object = k8sClient.getPyTorchJobClient().patch(getServerNamespace(), mlJob.getMetadata().getName(),
+ V1Patch.PATCH_FORMAT_APPLY_YAML,
+ new V1Patch(new Gson().toJson(mlJob)),
+ patchOptions).throwsApiException().getObject();
+ }
+
experiment = parseExperimentResponseObject(object, ParseOp.PARSE_OP_RESULT);
} catch (InvalidSpecException e) {
throw new SubmarineRuntimeException(409, e.getMessage());
@@ -281,18 +311,31 @@ public class K8sSubmitter implements Submitter {
MLJob mlJob = ExperimentSpecParser.parseJob(spec);
mlJob.getMetadata().setNamespace(getServerNamespace());
- AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(),
- mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? CustomResourceType.TFJob : CustomResourceType.PyTorchJob,
+ CustomResourceType customResourceType;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ customResourceType = CustomResourceType.TFJob;
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ customResourceType = CustomResourceType.XGBoost;
+ } else {
+ customResourceType = CustomResourceType.PyTorchJob;
+ }
+
+ AgentPod agentPod = new AgentPod(getServerNamespace(), spec.getMeta().getName(), customResourceType,
spec.getMeta().getExperimentId());
- Object object = mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)
- ? k8sClient.getTfJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
- MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
- .throwsApiException().getStatus()
- : k8sClient.getPyTorchJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
- MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
- .throwsApiException().getStatus();
+ Object object;
+ if (mlJob.getPlural().equals(TFJob.CRD_TF_PLURAL_V1)) {
+ object = k8sClient.getTfJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
+ MLJobConverter.toDeleteOptionsFromMLJob(mlJob)).throwsApiException().getStatus();
+ } else if (mlJob.getPlural().equals(XGBoostJob.CRD_XGBOOST_PLURAL_V1)) {
+ object = k8sClient.getXGBoostJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
+ MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
+ .throwsApiException().getStatus();
+ } else {
+ object = k8sClient.getPyTorchJobClient().delete(getServerNamespace(), mlJob.getMetadata().getName(),
+ MLJobConverter.toDeleteOptionsFromMLJob(mlJob))
+ .throwsApiException().getStatus();
+ }
LOG.info(String.format("Experiment:%s had been deleted, start to delete agent pod:%s",
spec.getMeta().getName(), agentPod.getMetadata().getName()));
@@ -539,7 +582,11 @@ public class K8sSubmitter implements Submitter {
if (experimentSpec.getMeta().getFramework()
.equalsIgnoreCase(ExperimentMeta.SupportedMLFramework.TENSORFLOW.getName())) {
return TF_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
- } else {
+ } else if (experimentSpec.getMeta().getFramework()
+ .equalsIgnoreCase(ExperimentMeta.SupportedMLFramework.XGBOOST.getName())) {
+ return XGBoost_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
+ }
+ else {
return PYTORCH_JOB_SELECTOR_KEY + experimentSpec.getMeta().getExperimentId();
}
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sClient.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sClient.java
index 165a19b2..02d56b21 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sClient.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sClient.java
@@ -42,6 +42,8 @@ import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobList;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobList;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobList;
public interface K8sClient {
@@ -66,6 +68,8 @@ public interface K8sClient {
GenericKubernetesApi<PyTorchJob, PyTorchJobList> getPyTorchJobClient();
+ GenericKubernetesApi<XGBoostJob, XGBoostJobList> getXGBoostJobClient();
+
GenericKubernetesApi<NotebookCR, NotebookCRList> getNotebookCRClient();
GenericKubernetesApi<SeldonDeployment, SeldonDeploymentList> getSeldonDeploymentClient();
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sDefaultClient.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sDefaultClient.java
index 0207e8fc..e335027a 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sDefaultClient.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/client/K8sDefaultClient.java
@@ -49,6 +49,8 @@ import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobList;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobList;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -75,6 +77,8 @@ public class K8sDefaultClient implements K8sClient {
private final GenericKubernetesApi<PyTorchJob, PyTorchJobList> pyTorchJobClient;
+ private final GenericKubernetesApi<XGBoostJob, XGBoostJobList> xgboostJobClient;
+
private final GenericKubernetesApi<NotebookCR, NotebookCRList> notebookCRClient;
private final GenericKubernetesApi<SeldonDeployment, SeldonDeploymentList> seldonDeploymentClient;
@@ -91,7 +95,7 @@ public class K8sDefaultClient implements K8sClient {
public K8sDefaultClient() {
String path = System.getenv(KUBECONFIG_ENV);
- if (StringUtils.isNotBlank(path)) {//
+ if (StringUtils.isNotBlank(path)) {
try (FileReader reader = new FileReader(path)) {
LOG.info("init by kubeconfig env path {}", path);
KubeConfig config = KubeConfig.loadKubeConfig(reader);
@@ -156,6 +160,12 @@ public class K8sDefaultClient implements K8sClient {
PyTorchJob.CRD_PYTORCH_GROUP_V1, PyTorchJob.CRD_PYTORCH_VERSION_V1,
PyTorchJob.CRD_PYTORCH_PLURAL_V1, client);
+ xgboostJobClient =
+ new GenericKubernetesApi<>(
+ XGBoostJob.class, XGBoostJobList.class,
+ XGBoostJob.CRD_XGBOOST_GROUP_V1, XGBoostJob.CRD_XGBOOST_VERSION_V1,
+ XGBoostJob.CRD_XGBOOST_PLURAL_V1, client);
+
notebookCRClient =
new GenericKubernetesApi<>(
NotebookCR.class, NotebookCRList.class,
@@ -225,6 +235,10 @@ public class K8sDefaultClient implements K8sClient {
return checkApi(pyTorchJobClient, PyTorchJob.class);
}
+ public GenericKubernetesApi<XGBoostJob, XGBoostJobList> getXGBoostJobClient() {
+ return checkApi(xgboostJobClient, XGBoostJob.class);
+ }
+
public GenericKubernetesApi<NotebookCR, NotebookCRList> getNotebookCRClient() {
return checkApi(notebookCRClient, NotebookCR.class);
}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
new file mode 100644
index 00000000..968ed7bc
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJob.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
+
+import com.google.gson.annotations.SerializedName;
+import org.apache.submarine.server.submitter.k8s.model.mljob.MLJob;
+
+public class XGBoostJob extends MLJob {
+
+ public static final String CRD_XGBOOST_KIND_V1 = "XGBoostJob";
+ public static final String CRD_XGBOOST_PLURAL_V1 = "xgboostjobs";
+ public static final String CRD_XGBOOST_GROUP_V1 = "kubeflow.org";
+ public static final String CRD_XGBOOST_VERSION_V1 = "v1";
+ public static final String CRD_XGBOOST_API_VERSION_V1 = CRD_XGBOOST_GROUP_V1 +
+ "/" + CRD_XGBOOST_VERSION_V1;
+
+ @SerializedName("spec")
+ private XGBoostJobSpec spec;
+
+ public XGBoostJob() {
+ setApiVersion(CRD_XGBOOST_API_VERSION_V1);
+ setKind(CRD_XGBOOST_KIND_V1);
+ setPlural(CRD_XGBOOST_PLURAL_V1);
+ setVersion(CRD_XGBOOST_VERSION_V1);
+ setGroup(CRD_XGBOOST_GROUP_V1);
+ }
+
+ /**
+ * Get the job spec which contains all the info for XGBoostJob.
+ * @return job spec
+ */
+ public XGBoostJobSpec getSpec() {
+ return spec;
+ }
+
+ /**
+ * Set the spec, the entry of the XGBoostJob
+ * @param spec job spec
+ */
+ public void setSpec(XGBoostJobSpec spec) {
+ this.spec = spec;
+ }
+}
+
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobList.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobList.java
new file mode 100644
index 00000000..2691965d
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobList.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
+
+import com.google.gson.annotations.SerializedName;
+import java.util.List;
+
+import io.kubernetes.client.common.KubernetesListObject;
+import io.kubernetes.client.openapi.models.V1ListMeta;
+
+/**
+ * List wrapper for {@link XGBoostJob} resources.
+ * <p>
+ * The {@code apiVersion} and {@code kind} fields exist for (de)serialization only;
+ * the corresponding getters always answer with the constants derived from
+ * {@link XGBoostJob}, not with the field values.
+ */
+public class XGBoostJobList implements KubernetesListObject {
+
+  @SerializedName("apiVersion")
+  private String apiVersion;
+
+  @SerializedName("kind")
+  private String kind;
+
+  @SerializedName("metadata")
+  private V1ListMeta metadata;
+
+  @SerializedName("items")
+  private List<XGBoostJob> items;
+
+  /**
+   * @return the list metadata, or {@code null} if none was deserialized
+   */
+  @Override
+  public V1ListMeta getMetadata() {
+    return this.metadata;
+  }
+
+  /**
+   * @return the jobs held by this list, or {@code null} if none were deserialized
+   */
+  @Override
+  public List<XGBoostJob> getItems() {
+    return this.items;
+  }
+
+  /**
+   * @return the fixed XGBoostJob API version ({@code group/version})
+   */
+  @Override
+  public String getApiVersion() {
+    return XGBoostJob.CRD_XGBOOST_API_VERSION_V1;
+  }
+
+  /**
+   * @return the XGBoostJob kind with the suffix {@code "List"} appended
+   */
+  @Override
+  public String getKind() {
+    return XGBoostJob.CRD_XGBOOST_KIND_V1 + "List";
+  }
+}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobReplicaType.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobReplicaType.java
new file mode 100644
index 00000000..acbe8d33
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobReplicaType.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
+
+import com.google.gson.annotations.SerializedName;
+import org.apache.submarine.server.submitter.k8s.model.mljob.MLJobReplicaType;
+
+/**
+ * Replica types that may appear as keys in an XGBoostJob spec.
+ */
+public enum XGBoostJobReplicaType implements MLJobReplicaType {
+
+  @SerializedName("Master")
+  Master("Master"),
+
+  @SerializedName("Worker")
+  Worker("Worker");
+
+  private final String typeName;
+
+  XGBoostJobReplicaType(String n) {
+    this.typeName = n;
+  }
+
+  /**
+   * Tell whether {@code type} names one of the constants of this enum.
+   * <p>
+   * The comparison is exact (case-sensitive) so that a {@code true} result
+   * guarantees {@link #valueOf(String)} succeeds for the same string. A
+   * case-insensitive check would accept e.g. {@code "master"}, which
+   * {@code valueOf} then rejects with an {@link IllegalArgumentException}.
+   *
+   * @param type candidate replica type name, may be {@code null}
+   * @return {@code true} if {@code type} equals the name of a constant,
+   *         {@code false} otherwise (including for {@code null})
+   */
+  public static boolean isSupportedReplicaType(String type) {
+    for (XGBoostJobReplicaType candidate : values()) {
+      if (candidate.name().equals(type)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * List the constant names of this enum in declaration order.
+   *
+   * @return a new array holding the name of every constant
+   */
+  public static String[] names() {
+    XGBoostJobReplicaType[] types = values();
+    String[] names = new String[types.length];
+    for (int i = 0; i < types.length; i++) {
+      names[i] = types[i].name();
+    }
+    return names;
+  }
+
+  /**
+   * @return the type name given to this constant at construction
+   */
+  @Override
+  public String getTypeName() {
+    return this.typeName;
+  }
+}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobSpec.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobSpec.java
new file mode 100644
index 00000000..d519b586
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/model/xgboostjob/XGBoostJobSpec.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
+
+import com.google.gson.annotations.SerializedName;
+import org.apache.submarine.server.submitter.k8s.model.mljob.MLJobReplicaSpec;
+
+import java.util.Map;
+
+/**
+ * The replica spec of XGBoostJob.
+ */
+/**
+ * The spec of an XGBoostJob: the per-replica-type specs plus the backoff limit.
+ */
+public class XGBoostJobSpec {
+  /**
+   * Key: Master, Worker
+   */
+  @SerializedName("xgbReplicaSpecs")
+  private Map<XGBoostJobReplicaType, MLJobReplicaSpec> replicaSpecs;
+
+  /**
+   * Serialized as {@code backoffLimit}; defaults to 3.
+   */
+  @SerializedName("backoffLimit")
+  private Integer backoffLimit = 3;
+
+  /**
+   * Get the replica specs.
+   *
+   * @return map from replica type to its spec, or {@code null} if unset
+   */
+  public Map<XGBoostJobReplicaType, MLJobReplicaSpec> getReplicaSpecs() {
+    return replicaSpecs;
+  }
+
+  /**
+   * Set replica specs
+   *
+   * @param replicaSpecs map from replica type to its spec
+   */
+  public void setReplicaSpecs(
+      Map<XGBoostJobReplicaType, MLJobReplicaSpec> replicaSpecs) {
+    this.replicaSpecs = replicaSpecs;
+  }
+
+  /**
+   * Get the backoff limit.
+   *
+   * @return the value serialized as {@code backoffLimit} (3 unless changed)
+   */
+  public Integer getBackoffLimit() {
+    return backoffLimit;
+  }
+
+  /**
+   * Set the backoff limit.
+   *
+   * @param backoffLimit the value to serialize as {@code backoffLimit}
+   */
+  public void setBackoffLimit(Integer backoffLimit) {
+    this.backoffLimit = backoffLimit;
+  }
+
+}
diff --git a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
index 91f57165..94e3c2c4 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/main/java/org/apache/submarine/server/submitter/k8s/parser/ExperimentSpecParser.java
@@ -52,6 +52,9 @@ import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobSpec
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobReplicaType;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobSpec;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobReplicaType;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobSpec;
import java.util.ArrayList;
import java.util.Arrays;
@@ -70,6 +73,9 @@ public class ExperimentSpecParser {
} else if (ExperimentMeta.SupportedMLFramework.PYTORCH.
getName().equalsIgnoreCase(framework)) {
return parsePyTorchJob(experimentSpec);
+ } else if (ExperimentMeta.SupportedMLFramework.XGBOOST.
+ getName().equalsIgnoreCase(framework)) {
+ return parseXGBoostJob(experimentSpec);
} else {
throw new InvalidSpecException("Unsupported framework name: " + framework +
". Supported frameworks are: " +
@@ -77,6 +83,43 @@ public class ExperimentSpecParser {
}
}
+ /**
+  * Build an {@link XGBoostJob} from an experiment spec: the metadata comes from
+  * {@code parseMetadata} and the job spec from {@code parseXGBoostJobSpec}.
+  *
+  * @param experimentSpec the experiment spec to convert
+  * @return the populated XGBoostJob
+  * @throws InvalidSpecException if the experiment spec cannot be converted
+  */
+ public static XGBoostJob parseXGBoostJob(
+     ExperimentSpec experimentSpec) throws InvalidSpecException {
+   final XGBoostJob job = new XGBoostJob();
+   job.setMetadata(parseMetadata(experimentSpec));
+   job.setSpec(parseXGBoostJobSpec(experimentSpec));
+   return job;
+ }
+
+ /**
+  * Build the {@link XGBoostJobSpec} from the task specs of an experiment.
+  * Every key of the experiment's spec map must name an
+  * {@link XGBoostJobReplicaType}; each task becomes one replica spec carrying
+  * the task's replica count and the pod template from {@code parseTemplateSpec}.
+  *
+  * @param experimentSpec the experiment spec to convert
+  * @return the XGBoostJob spec with one replica spec per task
+  * @throws InvalidSpecException if a key does not resolve to a replica type, or
+  *         if building a pod template fails
+  */
+ public static XGBoostJobSpec parseXGBoostJobSpec(ExperimentSpec experimentSpec)
+     throws InvalidSpecException {
+   XGBoostJobSpec xGBoostJobSpec = new XGBoostJobSpec();
+
+   Map<XGBoostJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>();
+
+   for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
+     String replicaType = entry.getKey();
+     ExperimentTaskSpec taskSpec = entry.getValue();
+
+     XGBoostJobReplicaType resolvedType = null;
+     if (XGBoostJobReplicaType.isSupportedReplicaType(replicaType)) {
+       try {
+         resolvedType = XGBoostJobReplicaType.valueOf(replicaType);
+       } catch (IllegalArgumentException ignored) {
+         // valueOf is case-sensitive: a name that passed the support check but
+         // is spelled differently from the enum constant is reported below as
+         // an invalid spec instead of escaping as an unchecked exception.
+       }
+     }
+     if (resolvedType == null) {
+       throw new InvalidSpecException("Unrecognized replica type name: " +
+           replicaType +
+           ", it should be " +
+           String.join(",", XGBoostJobReplicaType.names()) +
+           " for XGBoost experiment.");
+     }
+
+     MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
+     replicaSpec.setReplicas(taskSpec.getReplicas());
+     V1PodTemplateSpec podTemplateSpec = parseTemplateSpec(taskSpec, experimentSpec);
+
+     replicaSpec.setTemplate(podTemplateSpec);
+     replicaSpecMap.put(resolvedType, replicaSpec);
+   }
+   xGBoostJobSpec.setReplicaSpecs(replicaSpecMap);
+   return xGBoostJobSpec;
+ }
+
public static PyTorchJob parsePyTorchJob(
ExperimentSpec experimentSpec) throws InvalidSpecException {
PyTorchJob pyTorchJob = new PyTorchJob();
@@ -97,7 +140,7 @@ public class ExperimentSpecParser {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
V1PodTemplateSpec podTemplateSpec = parseTemplateSpec(taskSpec, experimentSpec);
-
+
replicaSpec.setTemplate(podTemplateSpec);
replicaSpecMap.put(PyTorchJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
@@ -137,12 +180,12 @@ public class ExperimentSpecParser {
for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
String replicaType = entry.getKey();
ExperimentTaskSpec taskSpec = entry.getValue();
-
+
if (TFJobReplicaType.isSupportedReplicaType(replicaType)) {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
V1PodTemplateSpec podTemplateSpec = parseTemplateSpec(taskSpec, experimentSpec);
-
+
replicaSpec.setTemplate(podTemplateSpec);
replicaSpecMap.put(TFJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
@@ -165,7 +208,6 @@ public class ExperimentSpecParser {
V1Container container = new V1Container();
container.setName(experimentSpec.getMeta().getFramework().toLowerCase());
// image
-
if (taskSpec.getImage() != null) {
container.setImage(taskSpec.getImage());
} else {
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
index 3c7fa686..4e2deb28 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/ExperimentSpecParserTest.java
@@ -43,6 +43,8 @@ import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobReplicaType;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobReplicaType;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobReplicaType;
import org.apache.submarine.server.submitter.k8s.parser.ExperimentSpecParser;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.AbstractCodeLocalizer;
import org.apache.submarine.server.submitter.k8s.experiment.codelocalizer.GitCodeLocalizer;
@@ -148,6 +150,47 @@ public class ExperimentSpecParserTest extends SpecBuilder {
}
}
+ /**
+  * Parses the XGBoost request fixture and checks the resulting job's metadata
+  * and the replica spec of every XGBoost replica type.
+  */
+ @Test
+ public void testValidXGBoostExperiment() throws IOException,
+     URISyntaxException, InvalidSpecException {
+   final String framework =
+       ExperimentMeta.SupportedMLFramework.XGBOOST.getName().toLowerCase();
+
+   ExperimentSpec experimentSpec = (ExperimentSpec) buildFromJsonFile(ExperimentSpec.class,
+       xgboostJobReqFile);
+   XGBoostJob xgboostJob = (XGBoostJob) ExperimentSpecParser.parseJob(experimentSpec);
+
+   validateMetadata(experimentSpec.getMeta(), xgboostJob.getMetadata(), framework);
+
+   // Master first, then Worker (enum declaration order).
+   for (XGBoostJobReplicaType replicaType : XGBoostJobReplicaType.values()) {
+     validateReplicaSpec(experimentSpec, xgboostJob, replicaType);
+   }
+ }
+
+ /**
+  * Checks that the parser rejects an unknown framework name and an unknown
+  * replica type name with an {@link InvalidSpecException}.
+  */
+ @Test
+ public void testInvalidXGBoostExperiment() throws IOException,
+     URISyntaxException {
+   // Case 1. Invalid framework name
+   ExperimentSpec experimentSpec = (ExperimentSpec) buildFromJsonFile(ExperimentSpec.class,
+       xgboostJobReqFile);
+   experimentSpec.getMeta().setFramework("fooframework");
+   try {
+     ExperimentSpecParser.parseJob(experimentSpec);
+     Assert.fail("It should throw InvalidSpecException");
+   } catch (InvalidSpecException e) {
+     Assert.assertTrue(e.getMessage().contains("Unsupported framework name"));
+   }
+
+   // Case 2. Invalid XGBoost replica name. It can only be "Master" or "Worker";
+   // here the Master task is re-registered under the unknown key "foo".
+   experimentSpec = (ExperimentSpec) buildFromJsonFile(ExperimentSpec.class, xgboostJobReqFile);
+   final String masterKey = XGBoostJobReplicaType.Master.getTypeName();
+   experimentSpec.getSpec().put("foo", experimentSpec.getSpec().remove(masterKey));
+   try {
+     ExperimentSpecParser.parseJob(experimentSpec);
+     Assert.fail("It should throw InvalidSpecException");
+   } catch (InvalidSpecException e) {
+     Assert.assertTrue(e.getMessage().contains("Unrecognized replica type name"));
+   }
+ }
+
private void validateMetadata(ExperimentMeta expectedMeta, V1ObjectMeta actualMeta,
String actualFramework) {
Assert.assertEquals(expectedMeta.getName(), actualMeta.getName());
@@ -160,18 +203,21 @@ public class ExperimentSpecParserTest extends SpecBuilder {
MLJobReplicaSpec mlJobReplicaSpec = null;
if (mlJob instanceof PyTorchJob) {
mlJobReplicaSpec = ((PyTorchJob) mlJob).getSpec().getReplicaSpecs().get(type);
- } else if (mlJob instanceof TFJob){
+ } else if (mlJob instanceof TFJob) {
mlJobReplicaSpec = ((TFJob) mlJob).getSpec().getReplicaSpecs().get(type);
+ } else if (mlJob instanceof XGBoostJob) {
+ mlJobReplicaSpec = ((XGBoostJob) mlJob).getSpec().getReplicaSpecs().get(type);
}
Assert.assertNotNull(mlJobReplicaSpec);
-
+
ExperimentTaskSpec definedPyTorchMasterTask = experimentSpec.getSpec().
get(type.getTypeName());
+
// replica
int expectedMasterReplica = definedPyTorchMasterTask.getReplicas();
Assert.assertEquals(expectedMasterReplica,
(int) mlJobReplicaSpec.getReplicas());
- // Image
+ // image
String expectedMasterImage = definedPyTorchMasterTask.getImage() == null ?
experimentSpec.getEnvironment().getImage() : definedPyTorchMasterTask.getImage();
String actualMasterImage = mlJobReplicaSpec.getContainerImageName();
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/SpecBuilder.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/SpecBuilder.java
index 573c295d..96f89082 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/SpecBuilder.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/SpecBuilder.java
@@ -36,6 +36,7 @@ public abstract class SpecBuilder {
// The spec files in test/resources
protected final String tfJobReqFile = "/tf_mnist_req.json";
protected final String pytorchJobReqFile = "/pytorch_job_req.json";
+ protected final String xgboostJobReqFile = "/xgboost_job_req.json";
protected final String pytorchJobWithEnvReqFile = "/pytorch_job_req_env.json";
protected final String pytorchJobWithInvalidEnvReqFile =
"/pytorch_job_req_invalid_env.json";
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/client/K8sMockClient.java b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/client/K8sMockClient.java
index 82b18321..684b30ca 100644
--- a/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/client/K8sMockClient.java
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/java/org/apache/submarine/server/submitter/k8s/client/K8sMockClient.java
@@ -46,6 +46,8 @@ import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJob;
import org.apache.submarine.server.submitter.k8s.model.pytorchjob.PyTorchJobList;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJob;
import org.apache.submarine.server.submitter.k8s.model.tfjob.TFJobList;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJob;
+import org.apache.submarine.server.submitter.k8s.model.xgboostjob.XGBoostJobList;
import java.io.File;
import java.io.IOException;
@@ -198,6 +200,11 @@ public class K8sMockClient implements K8sClient {
return null;
}
+ /**
+  * This mock provides no XGBoostJob client.
+  *
+  * @return always {@code null}
+  */
+ @Override
+ public GenericKubernetesApi<XGBoostJob, XGBoostJobList> getXGBoostJobClient() {
+ return null;
+ }
+
@Override
public GenericKubernetesApi<NotebookCR, NotebookCRList> getNotebookCRClient() {
return notebookCRClient;
diff --git a/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
new file mode 100644
index 00000000..c4ba97f4
--- /dev/null
+++ b/submarine-server/server-submitter/submitter-k8s/src/test/resources/xgboost_job_req.json
@@ -0,0 +1,26 @@
+{
+ "meta": {
+ "name": "xgboost-dist-mnist",
+ "namespace": "submarine",
+ "framework": "XGBoost",
+ "cmd": "python /opt/mlkube/main.py --job_type=Train --xgboost_parameter=objective:multi:softprob,num_class:3 --n_estimators=10 --learning_rate=0.1",
+ "envVars": {
+ "ENV_1": "ENV1"
+ }
+ },
+ "environment": {
+ "image": "apache/submarine:xgboost-dist-iris-1.0"
+ },
+ "spec": {
+ "Master": {
+ "name": "master",
+ "replicas": 1,
+ "resources": "cpu=2,memory=2048M"
+ },
+ "Worker": {
+ "name": "worker",
+ "replicas": 2,
+ "resources": "cpu=1,memory=1024M"
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscribe@submarine.apache.org
For additional commands, e-mail: dev-help@submarine.apache.org