You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/22 05:31:09 UTC
[1/2] systemml git commit: [SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc
Repository: systemml
Updated Branches:
refs/heads/master 54dbe9bb2 -> 15ecb723e
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
deleted file mode 100644
index 8605984..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-epoch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "EPOCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
new file mode 100644
index 0000000..4d0f32e
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-agg-service-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+ matrix[double] labels,
+ list[unknown] hyperparams,
+ list[unknown] model)
+ return (list[unknown] gradients) {
+ gradients = model
+}
+
+aggregation = function(list[unknown] model,
+ list[unknown] gradients,
+ list[unknown] hyperparams)
+ return (list[unknown] modelResult) {
+ modelResult = model
+ print(toString(as.matrix(gradients["agg_service_err"])))
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
deleted file mode 100644
index 31d44aa..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-spark-nn-bsp-batch-dc.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 16
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "REMOTE_SPARK")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
new file mode 100644
index 0000000..ad16122
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-spark-worker-failed.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+e1 = "element1"
+modelList = list(e1)
+X = matrix(1, rows=200, cols=30)
+Y = matrix(2, rows=200, cols=1)
+X_val = matrix(3, rows=200, cols=30)
+Y_val = matrix(4, rows=200, cols=1)
+
+gradients = function(matrix[double] features,
+ matrix[double] labels,
+ list[unknown] hyperparams,
+ list[unknown] model)
+ return (list[unknown] gradients) {
+ gradients = model
+ print(toString(as.matrix(gradients["worker_err"])))
+}
+
+aggregation = function(list[unknown] model,
+ list[unknown] gradients,
+ list[unknown] hyperparams)
+ return (list[unknown] modelResult) {
+ modelResult = model
+}
+
+e2 = "element2"
+params = list(e2)
+
+modelList = list("model")
+
+# Use paramserv function
+modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="REMOTE_SPARK", utype="BSP", epochs=10, hyperparams=params, k=1)
+
+print(toString(as.matrix(modelList2[1])))
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-test.dml b/src/test/scripts/functions/paramserv/paramserv-test.dml
new file mode 100644
index 0000000..b21e9c0
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-test.dml
@@ -0,0 +1,48 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, $epochs, $workers, $utype, $freq, $batchsize, $scheme, $mode)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, $batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
index d1b3a6d..a99035f 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/paramserv/ZPackageSuite.java
@@ -30,9 +30,11 @@ import org.junit.runners.Suite;
SparkDataPartitionerTest.class,
ParamservSyntaxTest.class,
SerializationTest.class,
+ RpcObjectTest.class,
ParamservRecompilationTest.class,
ParamservRuntimeNegativeTest.class,
- ParamservLocalNNTest.class
+ ParamservLocalNNTest.class,
+ ParamservSparkNNTest.class
})
[2/2] systemml git commit: [SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc
Posted by mb...@apache.org.
[SYSTEMML-2420,2422] New distributed paramserv spark workers and rpc
Closes #805.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/15ecb723
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/15ecb723
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/15ecb723
Branch: refs/heads/master
Commit: 15ecb723e39e3154412ca8f8824c4554ee64ca35
Parents: 54dbe9b
Author: EdgarLGB <gu...@atos.net>
Authored: Sat Jul 21 22:31:36 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Jul 21 22:31:36 2018 -0700
----------------------------------------------------------------------
.../controlprogram/paramserv/LocalPSWorker.java | 34 +++---
.../paramserv/LocalParamServer.java | 7 +-
.../controlprogram/paramserv/PSWorker.java | 15 ++-
.../controlprogram/paramserv/ParamServer.java | 39 ++++---
.../paramserv/ParamservUtils.java | 65 ++++++-----
.../paramserv/spark/SparkPSBody.java | 6 +-
.../paramserv/spark/SparkPSProxy.java | 68 +++++++++++
.../paramserv/spark/SparkPSWorker.java | 46 ++++++--
.../paramserv/spark/rpc/PSRpcCall.java | 97 ++++++++++++++++
.../paramserv/spark/rpc/PSRpcFactory.java | 57 ++++++++++
.../paramserv/spark/rpc/PSRpcHandler.java | 83 ++++++++++++++
.../paramserv/spark/rpc/PSRpcObject.java | 57 ++++++++++
.../paramserv/spark/rpc/PSRpcResponse.java | 112 +++++++++++++++++++
.../cp/ParamservBuiltinCPInstruction.java | 52 +++++++--
.../sysml/runtime/util/ProgramConverter.java | 11 +-
.../java/org/apache/sysml/utils/Statistics.java | 6 +
.../paramserv/ParamservLocalNNTest.java | 41 +++----
.../paramserv/ParamservSparkNNTest.java | 68 +++++++++--
.../functions/paramserv/RpcObjectTest.java | 56 ++++++++++
.../functions/paramserv/SerializationTest.java | 2 +-
.../paramserv/paramserv-nn-asp-batch.dml | 53 ---------
.../paramserv/paramserv-nn-asp-epoch.dml | 53 ---------
.../paramserv/paramserv-nn-bsp-batch-dc.dml | 53 ---------
.../paramserv/paramserv-nn-bsp-batch-dr.dml | 53 ---------
.../paramserv/paramserv-nn-bsp-batch-drr.dml | 53 ---------
.../paramserv/paramserv-nn-bsp-batch-or.dml | 53 ---------
.../paramserv/paramserv-nn-bsp-epoch.dml | 53 ---------
.../paramserv-spark-agg-service-failed.dml | 53 +++++++++
.../paramserv-spark-nn-bsp-batch-dc.dml | 53 ---------
.../paramserv/paramserv-spark-worker-failed.dml | 53 +++++++++
.../functions/paramserv/paramserv-test.dml | 48 ++++++++
.../functions/paramserv/ZPackageSuite.java | 4 +-
32 files changed, 961 insertions(+), 543 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index bbf2dbe..c23943d 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -35,6 +35,9 @@ import org.apache.sysml.utils.Statistics;
public class LocalPSWorker extends PSWorker implements Callable<Void> {
protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+ private static final long serialVersionUID = 5195390748495357295L;
+
+ protected LocalPSWorker() {}
public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
@@ -42,6 +45,11 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
}
@Override
+ public String getWorkerName() {
+ return String.format("Local worker_%d", _workerID);
+ }
+
+ @Override
public Void call() throws Exception {
if (DMLScript.STATISTICS)
Statistics.incWorkerNumber();
@@ -60,10 +68,10 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
}
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+ LOG.debug(String.format("%s: job finished.", getWorkerName()));
}
} catch (Exception e) {
- throw new DMLRuntimeException(String.format("Local worker_%d failed", _workerID), e);
+ throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
}
return null;
}
@@ -93,7 +101,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+ LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
}
}
@@ -108,9 +116,9 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
Statistics.accPSLocalModelUpdateTime((long) tUpd.stop());
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Local global parameter [size:%d kb] updated. "
+ LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
+ "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
- _workerID, globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
+ getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
}
return globalParams;
}
@@ -129,17 +137,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
}
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+ LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
}
}
}
private ListObject pullModel() {
// Pull the global parameters from ps
- ListObject globalParams = (ListObject)_ps.pull(_workerID);
+ ListObject globalParams = _ps.pull(_workerID);
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Successfully pull the global parameters "
- + "[size:%d kb] from ps.", _workerID, globalParams.getDataSize() / 1024));
+ LOG.debug(String.format("%s: successfully pull the global parameters "
+ + "[size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / 1024));
}
return globalParams;
}
@@ -148,8 +156,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
// Push the gradients to ps
_ps.push(_workerID, gradients);
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Successfully push the gradients "
- + "[size:%d kb] to ps.", _workerID, gradients.getDataSize() / 1024));
+ LOG.debug(String.format("%s: successfully push the gradients "
+ + "[size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / 1024));
}
}
@@ -168,8 +176,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
_ec.setVariable(Statement.PS_LABELS, bLabels);
if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("Local worker_%d: Got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
- + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", _workerID,
+ LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
+ + "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(),
bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
j + 1, totalIter));
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index 52372c9..0c73acb 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -22,11 +22,14 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
public class LocalParamServer extends ParamServer {
+ public LocalParamServer() {
+ super();
+ }
+
public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
super(model, aggFunc, updateType, ec, workerNum);
}
@@ -37,7 +40,7 @@ public class LocalParamServer extends ParamServer {
}
@Override
- public Data pull(int workerID) {
+ public ListObject pull(int workerID) {
ListObject model;
try {
model = _modelMap.get(workerID).take();
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 1ab5f5e..464db9b 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -21,6 +21,7 @@ package org.apache.sysml.runtime.controlprogram.paramserv;
import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.PS_FUNC_PREFIX;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.stream.Collectors;
@@ -34,7 +35,10 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
-public abstract class PSWorker {
+public abstract class PSWorker implements Serializable {
+
+ private static final long serialVersionUID = -3510485051178200118L;
+
protected int _workerID;
protected int _epochs;
protected long _batchSize;
@@ -50,10 +54,8 @@ public abstract class PSWorker {
protected String _updFunc;
protected Statement.PSFrequency _freq;
- protected PSWorker() {
+ protected PSWorker() {}
- }
-
protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
MatrixObject valFeatures, MatrixObject valLabels, ExecutionContext ec, ParamServer ps) {
_workerID = workerID;
@@ -65,7 +67,10 @@ public abstract class PSWorker {
_valLabels = valLabels;
_ec = ec;
_ps = ps;
+ setupUpdateFunction(updFunc, ec);
+ }
+ protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
// Get the update function
String[] cfn = ParamservUtils.getCompleteFuncName(updFunc, PS_FUNC_PREFIX);
String ns = cfn[0];
@@ -125,4 +130,6 @@ public abstract class PSWorker {
public MatrixObject getLabels() {
return _labels;
}
+
+ public abstract String getWorkerName();
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index bd8ee36..2607036 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -42,7 +42,6 @@ import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
-import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.utils.Statistics;
@@ -53,17 +52,19 @@ public abstract class ParamServer
protected static final boolean ACCRUE_BSP_GRADIENTS = true;
// worker input queues and global model
- protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
+ protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
private ListObject _model;
//aggregation service
- protected final ExecutionContext _ec;
- private final Statement.PSUpdateType _updateType;
- private final FunctionCallCPInstruction _inst;
- private final String _outputName;
- private final boolean[] _finishedStates; // Workers' finished states
+ protected ExecutionContext _ec;
+ private Statement.PSUpdateType _updateType;
+ private FunctionCallCPInstruction _inst;
+ private String _outputName;
+ private boolean[] _finishedStates; // Workers' finished states
private ListObject _accGradients = null;
+ protected ParamServer() {}
+
protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
// init worker queues and global model
_modelMap = new HashMap<>(workerNum);
@@ -77,10 +78,22 @@ public abstract class ParamServer
_ec = ec;
_updateType = updateType;
_finishedStates = new boolean[workerNum];
+ setupAggFunc(_ec, aggFunc);
+
+ // broadcast initial model
+ try {
+ broadcastModel();
+ }
+ catch (InterruptedException e) {
+ throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
+ }
+ }
+
+ public void setupAggFunc(ExecutionContext ec, String aggFunc) {
String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
String ns = cfn[0];
String fname = cfn[1];
- FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(ns, fname);
+ FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname);
ArrayList<DataIdentifier> inputs = func.getInputParams();
ArrayList<DataIdentifier> outputs = func.getOutputParams();
@@ -101,19 +114,11 @@ public abstract class ParamServer
ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
_inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, outputNames, "aggregate function");
-
- // broadcast initial model
- try {
- broadcastModel();
- }
- catch (InterruptedException e) {
- throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
- }
}
public abstract void push(int workerID, ListObject value);
- public abstract Data pull(int workerID);
+ public abstract ListObject pull(int workerID);
public ListObject getResult() {
// All the model updating work has terminated,
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index b9fd7a8..cf27457 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -28,8 +28,11 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysml.api.DMLScript;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.MultiThreadedHop;
@@ -57,6 +60,7 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
@@ -68,13 +72,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
import scala.Tuple2;
public class ParamservUtils {
+ protected static final Log LOG = LogFactory.getLog(ParamservUtils.class.getName());
public static final String PS_FUNC_PREFIX = "_ps_";
-
public static long SEED = -1; // Used for generating permutation
/**
@@ -140,6 +145,14 @@ public class ParamservUtils {
CacheableData<?> cd = (CacheableData<?>) data;
cd.enableCleanup(true);
ec.cleanupCacheableData(cd);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format("%s has been deleted.", cd.getFileName()));
+ }
+ }
+
+ public static void cleanupMatrixObject(ExecutionContext ec, MatrixObject mo) {
+ mo.enableCleanup(true);
+ ec.cleanupCacheableData(mo);
}
public static MatrixObject newMatrixObject(MatrixBlock mb) {
@@ -365,6 +378,7 @@ public class ParamservUtils {
@SuppressWarnings("unchecked")
public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
+ Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
// Get input RDD
JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
@@ -372,33 +386,34 @@ public class ParamservUtils {
sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);
DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
- return ParamservUtils.assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
+ JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils
+ .assembleTrainingData(features.getNumRows(), featuresRDD, labelsRDD) // Combine features and labels into a pair (rowBlockID => (features, labels))
.flatMapToPair(mapper) // Do the data partitioning on spark (workerID => (rowBlockID, (single row features, single row labels))
// Aggregate the partitioned matrix according to rowID for each worker
// i.e. (workerID => ordered list[(rowBlockID, (single row features, single row labels)]
- .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(),
- new Partitioner() {
- private static final long serialVersionUID = -7937781374718031224L;
- @Override
- public int getPartition(Object workerID) {
- return (int) workerID;
- }
- @Override
- public int numPartitions() {
- return workerNum;
- }
- },
- (list, input) -> {
- list.add(input);
- return list;
- },
- (l1, l2) -> {
- l1.addAll(l2);
- l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
- return l1;
- })
- .mapToPair(new DataPartitionerSparkAggregator(
- features.getNumColumns(), labels.getNumColumns()));
+ .aggregateByKey(new LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
+ private static final long serialVersionUID = -7937781374718031224L;
+ @Override
+ public int getPartition(Object workerID) {
+ return (int) workerID;
+ }
+ @Override
+ public int numPartitions() {
+ return workerNum;
+ }
+ }, (list, input) -> {
+ list.add(input);
+ return list;
+ }, (l1, l2) -> {
+ l1.addAll(l2);
+ l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
+ return l1;
+ })
+ .mapToPair(new DataPartitionerSparkAggregator(features.getNumColumns(), labels.getNumColumns()));
+
+ if (DMLScript.STATISTICS)
+ Statistics.accPSSetupTime((long) tSetup.stop());
+ return result;
}
public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
index ec10232..9354025 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
@@ -28,12 +28,10 @@ public class SparkPSBody {
private ExecutionContext _ec;
- public SparkPSBody() {
-
- }
+ public SparkPSBody() {}
public SparkPSBody(ExecutionContext ec) {
- this._ec = ec;
+ _ec = ec;
}
public ExecutionContext getEc() {
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
new file mode 100644
index 0000000..de7b6c6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
+
+import java.nio.ByteBuffer;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Parameter server proxy used by the spark workers: every push/pull request
+ * is forwarded as a synchronous rpc through the given transport client.
+ */
+public class SparkPSProxy extends ParamServer {
+
+	private TransportClient _client;
+	private final long _rpcTimeout;
+
+	/**
+	 * @param client transport client used to send the rpc calls
+	 * @param rpcTimeout timeout handed to each synchronous rpc call
+	 */
+	public SparkPSProxy(TransportClient client, long rpcTimeout) {
+		super();
+		_client = client;
+		_rpcTimeout = rpcTimeout;
+	}
+
+	/**
+	 * Sends the given list object with a PUSH rpc call.
+	 * @param workerID id of the calling worker
+	 * @param value list object to send
+	 * @throws DMLRuntimeException if the response is not successful
+	 */
+	@Override
+	public void push(int workerID, ListObject value) {
+		PSRpcResponse reply = roundTrip(PUSH, workerID, value);
+		if (reply.isSuccessful())
+			return;
+		throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, reply.getErrorMessage()));
+	}
+
+	/**
+	 * Requests a list object with a PULL rpc call.
+	 * @param workerID id of the calling worker
+	 * @return the result model carried by the response
+	 * @throws DMLRuntimeException if the response is not successful
+	 */
+	@Override
+	public ListObject pull(int workerID) {
+		PSRpcResponse reply = roundTrip(PULL, workerID, null);
+		if (reply.isSuccessful())
+			return reply.getResultModel();
+		throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, reply.getErrorMessage()));
+	}
+
+	/**
+	 * Serializes one rpc call, sends it synchronously, decodes the response and,
+	 * when statistics are enabled, accumulates the elapsed time as rpc request time.
+	 */
+	private PSRpcResponse roundTrip(String method, int workerID, ListObject payload) {
+		Timing timer = DMLScript.STATISTICS ? new Timing(true) : null;
+		ByteBuffer request = new PSRpcCall(method, workerID, payload).serialize();
+		PSRpcResponse reply = new PSRpcResponse(_client.sendRpcSync(request, _rpcTimeout));
+		if (DMLScript.STATISTICS)
+			Statistics.accPSRpcRequestTime((long) timer.stop());
+		return reply;
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
index 466801f..fa06243 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
@@ -20,43 +20,58 @@
package org.apache.sysml.runtime.controlprogram.paramserv.spark;
import java.io.IOException;
-import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.codegen.CodegenUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.PSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.ProgramConverter;
+import org.apache.sysml.utils.Statistics;
import scala.Tuple2;
-public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
+public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
private static final long serialVersionUID = -8674739573419648732L;
private String _program;
private HashMap<String, byte[]> _clsMap;
+ private String _host; // host ip of driver
+ private long _rpcTimeout; // rpc ask timeout
+ private String _aggFunc;
- protected SparkPSWorker() {
- // No-args constructor used for deserialization
- }
-
- public SparkPSWorker(String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap) {
+ public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, String host, long rpcTimeout) {
_updFunc = updFunc;
+ _aggFunc = aggFunc;
_freq = freq;
_epochs = epochs;
_batchSize = batchSize;
_program = program;
_clsMap = clsMap;
+ _host = host;
+ _rpcTimeout = rpcTimeout;
+ }
+
+ @Override
+ public String getWorkerName() {
+ return String.format("Spark worker_%d", _workerID);
}
@Override
public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
+ Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
configureWorker(input);
+ if (DMLScript.STATISTICS)
+ Statistics.accPSSetupTime((long) tSetup.stop());
+ call(); // Launch the worker
}
private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
@@ -73,5 +88,20 @@ public class SparkPSWorker extends PSWorker implements VoidFunction<Tuple2<Integ
// Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end
RemoteParForUtils.setupBufferPool(_workerID);
+
+ // Create the ps proxy
+ _ps = PSRpcFactory.createSparkPSProxy(_host, _rpcTimeout);
+
+ // Initialize the update function
+ setupUpdateFunction(_updFunc, _ec);
+
+ // Initialize the agg function
+ _ps.setupAggFunc(_ec, _aggFunc);
+
+ // Lazy initialize the matrix of features and labels
+ setFeatures(ParamservUtils.newMatrixObject(input._2._1));
+ setLabels(ParamservUtils.newMatrixObject(input._2._2));
+ _features.enableCleanup(false);
+ _labels.enableCleanup(false);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
new file mode 100644
index 0000000..999d409
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+import java.util.StringTokenizer;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+/**
+ * Rpc request exchanged between a worker and the parameter server. It carries
+ * the method name, the id of the calling worker and an optional list object.
+ */
+public class PSRpcCall extends PSRpcObject {
+
+	private static final String PS_RPC_CALL_BEGIN = CDATA_BEGIN + "PSRPCCALL" + LEVELIN;
+	private static final String PS_RPC_CALL_END = LEVELOUT + CDATA_END;
+
+	private String _method;
+	private int _workerID;
+	private ListObject _data;
+
+	/**
+	 * @param method rpc method name
+	 * @param workerID id of the calling worker
+	 * @param data list object to transfer, or null if the call carries no data
+	 */
+	public PSRpcCall(String method, int workerID, ListObject data) {
+		_method = method;
+		_workerID = workerID;
+		_data = data;
+	}
+
+	/**
+	 * Rebuilds a call from its serialized form.
+	 * @param buffer buffer holding a call produced by {@link #serialize()}
+	 */
+	public PSRpcCall(ByteBuffer buffer) {
+		deserialize(buffer);
+	}
+
+	@Override
+	public void deserialize(ByteBuffer buffer) {
+		//FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+		String raw = bufferToString(buffer);
+		// drop the begin/end markers that wrap the payload
+		String payload = raw.substring(PS_RPC_CALL_BEGIN.length(), raw.length() - PS_RPC_CALL_END.length());
+		StringTokenizer tokens = new StringTokenizer(payload, COMPONENTS_DELIM);
+
+		// component order: method, worker id, data
+		_method = tokens.nextToken();
+		_workerID = Integer.parseInt(tokens.nextToken());
+		String dataStr = tokens.nextToken();
+		if (dataStr.equals(EMPTY))
+			_data = null;
+		else
+			_data = (ListObject) ProgramConverter.parseDataObject(dataStr)[1];
+	}
+
+	@Override
+	public ByteBuffer serialize() {
+		//FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+		StringBuilder out = new StringBuilder();
+		out.append(PS_RPC_CALL_BEGIN)
+			.append(_method).append(COMPONENTS_DELIM)
+			.append(_workerID).append(COMPONENTS_DELIM);
+		if (_data != null) {
+			// export the cacheable entries before serializing the list object
+			flushListObject(_data);
+			out.append(ProgramConverter.serializeDataObject(DATA_KEY, _data));
+		} else {
+			out.append(EMPTY);
+		}
+		out.append(PS_RPC_CALL_END);
+		// NOTE(review): getBytes() uses the platform default charset, matching bufferToString on the reading side
+		return ByteBuffer.wrap(out.toString().getBytes());
+	}
+
+	public String getMethod() {
+		return _method;
+	}
+
+	public int getWorkerID() {
+		return _workerID;
+	}
+
+	public ListObject getData() {
+		return _data;
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
new file mode 100644
index 0000000..c8b4024
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
+
+//TODO should be able to configure the port by users
+public class PSRpcFactory {
+
+	private static final int PORT = 5055;
+	private static final String MODULE_NAME = "ps";
+
+	/**
+	 * Build a transport context whose rpc handler delegates to the given parameter server
+	 * @param ps parameter server handed to the rpc handler
+	 * @return transport context
+	 */
+	private static TransportContext createTransportContext(LocalParamServer ps) {
+		TransportConf transportConf = new TransportConf(MODULE_NAME, new SystemPropertyConfigProvider());
+		return new TransportContext(transportConf, new PSRpcHandler(ps));
+	}
+
+	/**
+	 * Create and start the server
+	 * @param ps parameter server that handles the incoming rpc calls
+	 * @param host host name the server is created with
+	 * @return server
+	 */
+	public static TransportServer createServer(LocalParamServer ps, String host) {
+		return createTransportContext(ps)
+			.createServer(host, PORT, Collections.emptyList());
+	}
+
+	/**
+	 * Create a proxy whose transport client is connected to the given host on the fixed port
+	 * @param host host name to connect to
+	 * @param rpcTimeout rpc timeout passed on to the proxy
+	 * @return parameter server proxy
+	 * @throws IOException if the client could not be created
+	 */
+	public static SparkPSProxy createSparkPSProxy(String host, long rpcTimeout) throws IOException {
+		// a transport context needs a handler, hence the fresh LocalParamServer instance
+		TransportContext clientContext = createTransportContext(new LocalParamServer());
+		return new SparkPSProxy(
+			clientContext.createClientFactory().createClient(host, PORT),
+			rpcTimeout);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
new file mode 100644
index 0000000..3d73a37
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.EMPTY_DATA;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.ERROR;
+import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.SUCCESS;
+
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Rpc handler that decodes incoming {@link PSRpcCall}s, dispatches them to the
+ * wrapped parameter server and replies with a serialized {@link PSRpcResponse}.
+ */
+public final class PSRpcHandler extends RpcHandler {
+
+	private LocalParamServer _server;
+
+	protected PSRpcHandler(LocalParamServer server) {
+		_server = server;
+	}
+
+	/**
+	 * Handles a single rpc call. Any exception raised by the parameter server
+	 * is reported to the caller as an ERROR response that carries the full stack trace.
+	 *
+	 * @param client transport client of the caller (not used)
+	 * @param buffer serialized rpc call
+	 * @param callback callback that receives the serialized response
+	 * @throws DMLRuntimeException if the requested method is neither push nor pull
+	 */
+	@Override
+	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
+		PSRpcCall call = new PSRpcCall(buffer);
+		PSRpcResponse response;
+		switch (call.getMethod()) {
+			case PUSH:
+				try {
+					_server.push(call.getWorkerID(), call.getData());
+					response = new PSRpcResponse(SUCCESS, EMPTY_DATA);
+				} catch (Exception exception) {
+					// catch broadly: a finally-based reply would hit a null response for
+					// anything but DMLRuntimeException and hide the real failure behind an NPE
+					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+				}
+				break;
+			case PULL:
+				try {
+					ListObject data = _server.pull(call.getWorkerID());
+					response = new PSRpcResponse(SUCCESS, data);
+				} catch (Exception exception) {
+					response = new PSRpcResponse(ERROR, ExceptionUtils.getFullStackTrace(exception));
+				}
+				break;
+			default:
+				throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
+		}
+		// reply exactly once, after the response has been assigned on every path
+		callback.onSuccess(response.serialize());
+	}
+
+	@Override
+	public StreamManager getStreamManager() {
+		return new OneForOneStreamManager();
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
new file mode 100644
index 0000000..c6d7fd3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Base class of the rpc messages; holds the shared constants and the
+ * helpers used by the concrete serialize/deserialize implementations.
+ */
+public abstract class PSRpcObject {
+
+	public static final String PUSH = "push";
+	public static final String PULL = "pull";
+	public static final String DATA_KEY = "data";
+	public static final String EMPTY_DATA = "";
+
+	public abstract void deserialize(ByteBuffer buffer);
+
+	public abstract ByteBuffer serialize();
+
+	/**
+	 * Convert the remaining content of a byte buffer to string.
+	 * Reads from the current position up to the limit and advances the position.
+	 * @param buffer byte buffer
+	 * @return string decoded with the platform default charset
+	 */
+	protected String bufferToString(ByteBuffer buffer) {
+		// size by remaining(), not limit(): the relative get starts at position(),
+		// so limit() bytes would underflow whenever position() > 0
+		byte[] result = new byte[buffer.remaining()];
+		buffer.get(result);
+		return new String(result);
+	}
+
+	/**
+	 * Export every cacheable entry of the list object
+	 * (per the original author's note, this flushes the data into HDFS)
+	 * @param data list object
+	 */
+	protected void flushListObject(ListObject data) {
+		data.getData().stream().filter(d -> d instanceof CacheableData)
+			.forEach(d -> ((CacheableData<?>) d).exportData());
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
new file mode 100644
index 0000000..998c523
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
+
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_BEGIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.CDATA_END;
+import static org.apache.sysml.runtime.util.ProgramConverter.COMPONENTS_DELIM;
+import static org.apache.sysml.runtime.util.ProgramConverter.EMPTY;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELIN;
+import static org.apache.sysml.runtime.util.ProgramConverter.LEVELOUT;
+
+import java.nio.ByteBuffer;
+import java.util.StringTokenizer;
+
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+public class PSRpcResponse extends PSRpcObject {
+
+ public static final int SUCCESS = 1;
+ public static final int ERROR = 2;
+
+ private static final String PS_RPC_RESPONSE_BEGIN = CDATA_BEGIN + "PSRPCRESPONSE" + LEVELIN;
+ private static final String PS_RPC_RESPONSE_END = LEVELOUT + CDATA_END;
+
+ private int _status;
+ private Object _data; // Could be list object or exception
+
+ public PSRpcResponse(ByteBuffer buffer) {
+ deserialize(buffer);
+ }
+
+ public PSRpcResponse(int status, Object data) {
+ _status = status;
+ _data = data;
+ }
+
+ public boolean isSuccessful() {
+ return _status == SUCCESS;
+ }
+
+ public String getErrorMessage() {
+ return (String) _data;
+ }
+
+ public ListObject getResultModel() {
+ return (ListObject) _data;
+ }
+
+ @Override
+ public void deserialize(ByteBuffer buffer) {
+ //FIXME: instead of shallow deserialize + read, we should do a deep deserialize of the matrix blocks.
+ String input = bufferToString(buffer);
+ //header elimination
+ input = input.substring(PS_RPC_RESPONSE_BEGIN.length(), input.length() - PS_RPC_RESPONSE_END.length()); //remove start/end
+ StringTokenizer st = new StringTokenizer(input, COMPONENTS_DELIM);
+
+ _status = Integer.valueOf(st.nextToken());
+ String data = st.nextToken();
+ switch (_status) {
+ case SUCCESS:
+ _data = data.equals(EMPTY) ? null :
+ ProgramConverter.parseDataObject(data)[1];
+ break;
+ case ERROR:
+ _data = data;
+ break;
+ }
+ }
+
+ @Override
+ public ByteBuffer serialize() {
+ //FIXME: instead of export+shallow serialize, we should do a deep serialize of the matrix blocks.
+
+ StringBuilder sb = new StringBuilder();
+ sb.append(PS_RPC_RESPONSE_BEGIN);
+ sb.append(_status);
+ sb.append(COMPONENTS_DELIM);
+ switch (_status) {
+ case SUCCESS:
+ if (_data.equals(EMPTY_DATA)) {
+ sb.append(EMPTY);
+ } else {
+ flushListObject((ListObject) _data);
+ sb.append(ProgramConverter.serializeDataObject(DATA_KEY, (ListObject) _data));
+ }
+ break;
+ case ERROR:
+ sb.append(_data.toString());
+ break;
+ }
+ sb.append(PS_RPC_RESPONSE_END);
+ return ByteBuffer.wrap(sb.toString().getBytes());
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 4e7a718..6133987 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -55,6 +55,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
+import org.apache.spark.network.server.TransportServer;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.lops.LopProperties;
@@ -71,6 +72,7 @@ import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -114,16 +116,16 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
}
private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
+ Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+
PSScheme scheme = getScheme();
int workerNum = getWorkerNum(mode);
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
- int k = getParLevel(workerNum);
-
// Get the compiled execution context
LocalVariableMap newVarsMap = createVarsMap(sec);
- ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, k);
+ ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1); // level of par is 1 in spark backend
MatrixObject features = sec.getMatrixObject(getParam(PS_FEATURES));
MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
@@ -131,16 +133,47 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
// Force all the instructions to CP type
Recompiler.recompileProgramBlockHierarchy2Forced(
newEC.getProgram().getProgramBlocks(), 0, new HashSet<>(), LopProperties.ExecType.CP);
-
+
// Serialize all the needed params for remote workers
SparkPSBody body = new SparkPSBody(newEC);
HashMap<String, byte[]> clsMap = new HashMap<>();
String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
- SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getFrequency(), getEpochs(), getBatchSize(), program, clsMap);
- ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
- .foreach(worker); // Run remote workers
+ // Get some configurations
+ String host = sec.getSparkContext().getConf().get("spark.driver.host");
+ long rpcTimeout = sec.getSparkContext().getConf().contains("spark.rpc.askTimeout") ?
+ sec.getSparkContext().getConf().getTimeAsMs("spark.rpc.askTimeout") :
+ sec.getSparkContext().getConf().getTimeAsMs("spark.network.timeout", "120s");
+
+ // Create remote workers
+ SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), getFrequency(),
+ getEpochs(), getBatchSize(), program, clsMap, host, rpcTimeout);
+
+ // Create the agg service's execution context
+ ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+
+ // Create the parameter server
+ ListObject model = sec.getListObject(getParam(PS_MODEL));
+ ParamServer ps = createPS(mode, aggFunc, getUpdateType(), workerNum, model, aggServiceEC);
+
+ if (DMLScript.STATISTICS)
+ Statistics.accPSSetupTime((long) tSetup.stop());
+
+ // Create the netty server for ps
+ TransportServer server = PSRpcFactory.createServer((LocalParamServer) ps, host); // Start the server
+ try {
+ ParamservUtils.doPartitionOnSpark(sec, features, labels, scheme, workerNum) // Do data partitioning
+ .foreach(worker); // Run remote workers
+ } catch (Exception e) {
+ throw new DMLRuntimeException("Paramserv function failed: ", e);
+ } finally {
+ // Stop the netty server
+ server.close();
+ }
+
+ // Fetch the final model from ps
+ sec.setVariable(output.getName(), ps.getResult());
}
private void runLocally(ExecutionContext ec, PSModeType mode) {
@@ -176,8 +209,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
- .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
- .collect(Collectors.toList());
+ .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, getBatchSize(), valFeatures, valLabels, workerECs.get(i), ps))
+ .collect(Collectors.toList());
// Do data partition
PSScheme scheme = getScheme();
@@ -296,6 +329,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
private ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
switch (mode) {
case LOCAL:
+ case REMOTE_SPARK:
return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
default:
throw new DMLRuntimeException("Unsupported parameter server: "+mode.name());
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
index 1d2115e..fc9d9b4 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
@@ -143,7 +143,7 @@ public class ProgramConverter
public static final String PB_IF = " IF" + LEVELIN;
public static final String PB_FC = " FC" + LEVELIN;
public static final String PB_EFC = " EFC" + LEVELIN;
-
+
public static final String CONF_STATS = "stats";
// Used for parfor
@@ -716,9 +716,10 @@ public class ProgramConverter
builder.append(rSerializeProgramBlocks(ec.getProgram().getProgramBlocks(), clsMap));
builder.append(PBS_END);
builder.append(NEWLINE);
+ builder.append(COMPONENTS_DELIM);
+ builder.append(NEWLINE);
builder.append(PSBODY_END);
-
return builder.toString();
}
@@ -868,7 +869,7 @@ public class ProgramConverter
value = mo.getFileName();
PartitionFormat partFormat = (mo.getPartitionFormat()!=null) ? new PartitionFormat(
mo.getPartitionFormat(),mo.getPartitionSize()) : PartitionFormat.NONE;
- metaData = new String[9];
+ metaData = new String[11];
metaData[0] = String.valueOf( mc.getRows() );
metaData[1] = String.valueOf( mc.getCols() );
metaData[2] = String.valueOf( mc.getRowsPerBlock() );
@@ -878,6 +879,8 @@ public class ProgramConverter
metaData[6] = OutputInfo.outputInfoToString( md.getOutputInfo() );
metaData[7] = String.valueOf( partFormat );
metaData[8] = String.valueOf( mo.getUpdateType() );
+ metaData[9] = String.valueOf(mo.isHDFSFileExists());
+ metaData[10] = String.valueOf(mo.isCleanupEnabled());
break;
case LIST:
// SCHEMA: <name>|<datatype>|<valuetype>|value|<metadata>|<tab>element1<tab>element2<tab>element3 (this is the list)
@@ -1683,6 +1686,8 @@ public class ProgramConverter
if( partFormat._dpf != PDataPartitionFormat.NONE )
mo.setPartitioned( partFormat._dpf, partFormat._N );
mo.setUpdateType(inplace);
+ mo.setHDFSFileExists(Boolean.valueOf(st.nextToken()));
+ mo.enableCleanup(Boolean.valueOf(st.nextToken()));
dat = mo;
break;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 8f0d853..1dd8362 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -125,6 +125,7 @@ public class Statistics
private static final LongAdder psLocalModelUpdateTime = new LongAdder();
private static final LongAdder psModelBroadcastTime = new LongAdder();
private static final LongAdder psBatchIndexTime = new LongAdder();
+ private static final LongAdder psRpcRequestTime = new LongAdder();
//PARFOR optimization stats (low frequency updates)
private static long parforOptTime = 0; //in milli sec
@@ -564,6 +565,10 @@ public class Statistics
psBatchIndexTime.add(t);
}
+ public static void accPSRpcRequestTime(long t) {
+ psRpcRequestTime.add(t);
+ }
+
public static String getCPHeavyHitterCode( Instruction inst )
{
String opcode = null;
@@ -1003,6 +1008,7 @@ public class Statistics
psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 1000));
sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+ sb.append(String.format("Paramserv RPC request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
}
if( parforOptCount>0 ){
sb.append("ParFor loops optimized:\t\t" + getParforOptCount() + ".\n");
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
index d5fd509..905bfd1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservLocalNNTest.java
@@ -19,75 +19,66 @@
package org.apache.sysml.test.integration.functions.paramserv;
+import org.apache.sysml.parser.Statement;
import org.apache.sysml.test.integration.AutomatedTestBase;
import org.apache.sysml.test.integration.TestConfiguration;
import org.junit.Test;
public class ParamservLocalNNTest extends AutomatedTestBase {
- private static final String TEST_NAME1 = "paramserv-nn-bsp-batch-dc";
- private static final String TEST_NAME2 = "paramserv-nn-asp-batch";
- private static final String TEST_NAME3 = "paramserv-nn-bsp-epoch";
- private static final String TEST_NAME4 = "paramserv-nn-asp-epoch";
- private static final String TEST_NAME5 = "paramserv-nn-bsp-batch-drr";
- private static final String TEST_NAME6 = "paramserv-nn-bsp-batch-dr";
- private static final String TEST_NAME7 = "paramserv-nn-bsp-batch-or";
+ private static final String TEST_NAME = "paramserv-test";
private static final String TEST_DIR = "functions/paramserv/";
private static final String TEST_CLASS_DIR = TEST_DIR + ParamservLocalNNTest.class.getSimpleName() + "/";
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
- addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
- addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
- addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {}));
- addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
- addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
- addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
}
@Test
public void testParamservBSPBatchDisjointContiguous() {
- runDMLTest(TEST_NAME1);
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
}
@Test
public void testParamservASPBatch() {
- runDMLTest(TEST_NAME2);
+ runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
}
@Test
public void testParamservBSPEpoch() {
- runDMLTest(TEST_NAME3);
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
}
@Test
public void testParamservASPEpoch() {
- runDMLTest(TEST_NAME4);
+ runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS);
}
@Test
public void testParamservBSPBatchDisjointRoundRobin() {
- runDMLTest(TEST_NAME5);
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN);
}
@Test
public void testParamservBSPBatchDisjointRandom() {
- runDMLTest(TEST_NAME6);
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM);
}
@Test
public void testParamservBSPBatchOverlapReshuffle() {
- runDMLTest(TEST_NAME7);
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE);
}
- private void runDMLTest(String testname) {
- TestConfiguration config = getTestConfiguration(testname);
+ private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+ TestConfiguration config = getTestConfiguration(ParamservLocalNNTest.TEST_NAME);
loadTestConfiguration(config);
- programArgs = new String[] { "-explain" };
+ programArgs = new String[] { "-explain", "-nvargs", "mode=LOCAL", "epochs=" + epochs,
+ "workers=" + workers, "utype=" + utype, "freq=" + freq, "batchsize=" + batchsize,
+ "scheme=" + scheme };
String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
+ fullDMLScriptName = HOME + ParamservLocalNNTest.TEST_NAME + ".dml";
runTest(true, false, null, null, -1);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
index 2441116..30eccb3 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservSparkNNTest.java
@@ -1,14 +1,24 @@
package org.apache.sysml.test.integration.functions.paramserv;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.test.integration.AutomatedTestBase;
import org.apache.sysml.test.integration.TestConfiguration;
import org.junit.Test;
public class ParamservSparkNNTest extends AutomatedTestBase {
- private static final String TEST_NAME1 = "paramserv-spark-nn-bsp-batch-dc";
+ private static final String TEST_NAME1 = "paramserv-test";
+ private static final String TEST_NAME2 = "paramserv-spark-worker-failed";
+ private static final String TEST_NAME3 = "paramserv-spark-agg-service-failed";
private static final String TEST_DIR = "functions/paramserv/";
private static final String TEST_CLASS_DIR = TEST_DIR + ParamservSparkNNTest.class.getSimpleName() + "/";
@@ -16,14 +26,42 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
@Override
public void setUp() {
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {}));
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {}));
}
@Test
public void testParamservBSPBatchDisjointContiguous() {
- runDMLTest(TEST_NAME1);
+ runDMLTest(2, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+ }
+
+ @Test
+ public void testParamservASPBatchDisjointContiguous() {
+ runDMLTest(2, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.BATCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+ }
+
+ @Test
+ public void testParamservBSPEpochDisjointContiguous() {
+ runDMLTest(10, 3, Statement.PSUpdateType.BSP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
+ }
+
+ @Test
+ public void testParamservASPEpochDisjointContiguous() {
+ runDMLTest(10, 3, Statement.PSUpdateType.ASP, Statement.PSFrequency.EPOCH, 16, Statement.PSScheme.DISJOINT_CONTIGUOUS);
}
- private void runDMLTest(String testname) {
+ @Test
+ public void testParamservWorkerFailed() {
+ runDMLTest(TEST_NAME2, true, DMLException.class, "Invalid indexing by name in unnamed list: worker_err.");
+ }
+
+ @Test
+ public void testParamservAggServiceFailed() {
+ runDMLTest(TEST_NAME3, true, DMLException.class, "Invalid indexing by name in unnamed list: agg_service_err.");
+ }
+
+ private void runDMLTest(String testname, boolean exceptionExpected, Class<?> expectedException, String errMessage) {
+ programArgs = new String[] { "-explain" };
DMLScript.RUNTIME_PLATFORM oldRtplatform = AutomatedTestBase.rtplatform;
boolean oldUseLocalSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
AutomatedTestBase.rtplatform = DMLScript.RUNTIME_PLATFORM.SPARK;
@@ -32,16 +70,32 @@ public class ParamservSparkNNTest extends AutomatedTestBase {
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);
- programArgs = new String[] { "-explain" };
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- // The test is not already finished, so it is normal to have the NPE
- runTest(true, true, DMLException.class, null, -1);
+ runTest(true, exceptionExpected, expectedException, errMessage, -1);
} finally {
AutomatedTestBase.rtplatform = oldRtplatform;
DMLScript.USE_LOCAL_SPARK_CONFIG = oldUseLocalSparkConfig;
}
-
}
+ private void runDMLTest(int epochs, int workers, Statement.PSUpdateType utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme) {
+ Script script = dmlFromFile(SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml").in("$mode", Statement.PSModeType.REMOTE_SPARK.toString())
+ .in("$epochs", String.valueOf(epochs))
+ .in("$workers", String.valueOf(workers))
+ .in("$utype", utype.toString())
+ .in("$freq", freq.toString())
+ .in("$batchsize", String.valueOf(batchsize))
+ .in("$scheme", scheme.toString());
+
+ SparkConf conf = SparkExecutionContext.createSystemMLSparkConf().setAppName("ParamservSparkNNTest").setMaster("local[*]")
+ .set("spark.driver.allowMultipleContexts", "true");
+ JavaSparkContext sc = new JavaSparkContext(conf);
+ MLContext ml = new MLContext(sc);
+ ml.setStatistics(true);
+ ml.execute(script);
+ ml.resetConfig();
+ sc.stop();
+ ml.close();
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
new file mode 100644
index 0000000..57e1106
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.paramserv;
+
+import java.util.Arrays;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.instructions.cp.IntObject;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RpcObjectTest {
+
+ @Test
+ public void testPSRpcCall() {
+ MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+ MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+ IntObject io = new IntObject(30);
+ ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+ PSRpcCall expected = new PSRpcCall(PSRpcObject.PUSH, 1, lo);
+ PSRpcCall actual = new PSRpcCall(expected.serialize());
+ Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+ }
+
+ @Test
+ public void testPSRpcResponse() {
+ MatrixObject mo1 = SerializationTest.generateDummyMatrix(10);
+ MatrixObject mo2 = SerializationTest.generateDummyMatrix(20);
+ IntObject io = new IntObject(30);
+ ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
+ PSRpcResponse expected = new PSRpcResponse(PSRpcResponse.SUCCESS, lo);
+ PSRpcResponse actual = new PSRpcResponse(expected.serialize());
+ Assert.assertEquals(new String(expected.serialize().array()), new String(actual.serialize().array()));
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
index 2a08ca6..64d6492 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SerializationTest.java
@@ -68,7 +68,7 @@ public class SerializationTest {
Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
}
- private MatrixObject generateDummyMatrix(int size) {
+ public static MatrixObject generateDummyMatrix(int size) {
double[] dl = new double[size];
for (int i = 0; i < size; i++) {
dl[i] = i;
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
deleted file mode 100644
index ba22942..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-batch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "BATCH", batchsize,"DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml b/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
deleted file mode 100644
index c8c6a2f..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-asp-epoch.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "ASP", "EPOCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
deleted file mode 100644
index 78fc1c4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dc.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_CONTIGUOUS", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
deleted file mode 100644
index 9191b5a..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-dr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_RANDOM", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
deleted file mode 100644
index ec18cb4..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-drr.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 4
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "DISJOINT_ROUND_ROBIN", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/15ecb723/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml b/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
deleted file mode 100644
index 928dde2..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-nn-bsp-batch-or.dml
+++ /dev/null
@@ -1,53 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
-source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
-
-# Generate the training data
-[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-n = nrow(images)
-
-# Generate the training data
-[X, Y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
-
-# Split into training and validation
-val_size = n * 0.1
-X = images[(val_size+1):n,]
-X_val = images[1:val_size,]
-Y = labels[(val_size+1):n,]
-Y_val = labels[1:val_size,]
-
-# Arguments
-epochs = 10
-workers = 2
-batchsize = 32
-
-# Train
-[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, "BSP", "BATCH", batchsize, "OVERLAP_RESHUFFLE", "LOCAL")
-
-# Compute validation loss & accuracy
-probs_val = mnist_lenet::predict(X_val, C, Hin, Win, batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
-loss_val = cross_entropy_loss::forward(probs_val, Y_val)
-accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
-
-# Output results
-print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file