Posted to commits@systemml.apache.org by mb...@apache.org on 2018/05/29 06:21:23 UTC
[1/2] systemml git commit: [SYSTEMML-2085] Initial version of local backend for paramserv builtin
Repository: systemml
Updated Branches:
refs/heads/master c7a9e016d -> 97018d4e6
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
new file mode 100644
index 0000000..4d23b8c
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-minimum-version.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-miss-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-miss-args.dml b/src/test/scripts/functions/paramserv/paramserv-miss-args.dml
index f3a2c91..6ceb9ad 100644
--- a/src/test/scripts/functions/paramserv/paramserv-miss-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-miss-args.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
e1 = "element1"
-modelList = list(e1)
+modelList = list(e1=e1)
X = matrix(1, rows=2, cols=3)
Y = matrix(2, rows=2, cols=3)
X_val = matrix(3, rows=2, cols=3)
@@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) {
}
e2 = "element2"
-params = list(e2)
+params = list(e2=e2)
# Use paramserv function
# Miss "features" parameterized argument
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-nn-test.dml b/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
new file mode 100644
index 0000000..740a208
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-nn-test.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet::generate_dummy_data()
+n = nrow(images)
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Arguments
+epochs = 10
+workers = 2
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet::predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
index c504303..6d06ce2 100644
--- a/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-without-optional-args.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
e1 = "element1"
-modelList = list(e1)
+modelList = list(e1=e1)
X = matrix(1, rows=2, cols=3)
Y = matrix(2, rows=2, cols=3)
X_val = matrix(3, rows=2, cols=3)
@@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) {
}
e2 = "element2"
-params = list(e2)
+params = list(e2=e2)
# Use paramserv function
# Remove the optional "hyperparams"
[2/2] systemml git commit: [SYSTEMML-2085] Initial version of local backend for paramserv builtin
Posted by mb...@apache.org.
[SYSTEMML-2085] Initial version of local backend for paramserv builtin
Closes #771.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97018d4e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97018d4e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97018d4e
Branch: refs/heads/master
Commit: 97018d4e688ba7eeaaa4567ca1e174a3c5525468
Parents: c7a9e01
Author: EdgarLGB <gu...@atos.net>
Authored: Mon May 28 23:17:18 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon May 28 23:17:20 2018 -0700
----------------------------------------------------------------------
.../ParameterizedBuiltinFunctionExpression.java | 6 +-
.../java/org/apache/sysml/parser/Statement.java | 1 +
.../context/ExecutionContext.java | 13 +-
.../controlprogram/paramserv/LocalPSWorker.java | 97 +++++
.../paramserv/LocalParamServer.java | 59 +++
.../controlprogram/paramserv/PSWorker.java | 131 +++++++
.../controlprogram/paramserv/ParamServer.java | 232 +++++++++++
.../paramserv/ParamservUtils.java | 97 +++++
.../runtime/instructions/cp/CPOperand.java | 2 +-
.../runtime/instructions/cp/ListObject.java | 14 +
.../cp/MatrixIndexingCPInstruction.java | 4 +-
.../cp/ParamservBuiltinCPInstruction.java | 257 ++++++++++++-
.../test/integration/AutomatedTestBase.java | 18 +-
.../functions/paramserv/ParamservFuncTest.java | 29 +-
.../paramserv/mnist_lenet_paramserv.dml | 383 +++++++++++++++++++
.../mnist_lenet_paramserv_minimum_version.dml | 377 ++++++++++++++++++
.../functions/paramserv/paramserv-all-args.dml | 4 +-
.../functions/paramserv/paramserv-ipa-test.dml | 47 ---
.../paramserv/paramserv-minimum-version.dml | 52 +++
.../functions/paramserv/paramserv-miss-args.dml | 4 +-
.../functions/paramserv/paramserv-nn-test.dml | 52 +++
.../paramserv-without-optional-args.dml | 4 +-
22 files changed, 1805 insertions(+), 78 deletions(-)
----------------------------------------------------------------------
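For orientation, the new builtin is driven entirely by named parameters, validated in ParameterizedBuiltinFunctionExpression below. The following is a hedged sketch of an end-to-end invocation, not code from this commit: the names model, features, val_features, upd, agg, mode, utype, and hyperparams appear in the diffs and tests, while labels, val_labels, freq, epochs, batch_size, k, and scheme are assumptions inferred from the Statement constants; ./nn.dml and its functions are hypothetical.

# hypothetical DML invocation of the paramserv builtin
model_result = paramserv(model=modelList, features=X, labels=Y,
  val_features=X_val, val_labels=Y_val,
  upd="./nn.dml::gradients", agg="./nn.dml::aggregation",
  mode="LOCAL", utype="BSP", freq="BATCH", epochs=10, batch_size=64,
  k=2, scheme="DISJOINT_CONTIGUOUS", hyperparams=params)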
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 3d74f8d..99aec78 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -341,12 +341,12 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
.collect(Collectors.toSet());
checkStringParam(false, fname, Statement.PS_UPDATE_TYPE, utypes, conditional);
Set<String> frequencies = Arrays.stream(Statement.PSFrequency.values()).map(Enum::name).collect(Collectors.toSet());
- checkStringParam(false, fname, Statement.PS_FREQUENCY, frequencies, conditional);
+ checkStringParam(true, fname, Statement.PS_FREQUENCY, frequencies, conditional);
checkDataValueType(false, fname, Statement.PS_EPOCHS, DataType.SCALAR, ValueType.INT, conditional);
checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT, conditional);
- checkDataValueType(false, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
+ checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT, conditional);
Set<String> schemes = Arrays.stream(Statement.PSScheme.values()).map(Enum::name).collect(Collectors.toSet());
- checkStringParam(false, fname, Statement.PS_SCHEME, schemes, conditional);
+ checkStringParam(true, fname, Statement.PS_SCHEME, schemes, conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
Set<String> checkpointings = Arrays.stream(Statement.PSCheckpointing.values()).map(Enum::name).collect(Collectors.toSet());
checkStringParam(true, fname, Statement.PS_CHECKPOINTING, checkpointings, conditional);
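The flags flipped from false to true above mark freq, k, and scheme as optional, matching the defaults introduced in ParamservBuiltinCPInstruction further down (BATCH, the local parallelism, and DISJOINT_CONTIGUOUS). A minimal sketch of a call relying on those defaults, with the same hedged parameter names as in the earlier example:

model_result = paramserv(model=modelList, features=X, labels=Y,
  val_features=X_val, val_labels=Y_val,
  upd="./nn.dml::gradients", agg="./nn.dml::aggregation",
  mode="LOCAL", utype="BSP", epochs=10)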
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/parser/Statement.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Statement.java b/src/main/java/org/apache/sysml/parser/Statement.java
index 4853a47..1987d31 100644
--- a/src/main/java/org/apache/sysml/parser/Statement.java
+++ b/src/main/java/org/apache/sysml/parser/Statement.java
@@ -71,6 +71,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_UPDATE_FUN = "upd";
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
+ public static final String PS_GRADIENTS = "gradients";
public enum PSModeType {
LOCAL, REMOTE_SPARK
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
index 67b2a83..6807848 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java
@@ -42,6 +42,7 @@ import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;
import org.apache.sysml.runtime.instructions.gpu.context.GPUContext;
@@ -443,7 +444,17 @@ public class ExecutionContext {
public void setScalarOutput(String varName, ScalarObject so) {
setVariable(varName, so);
}
-
+
+ public ListObject getListObject(String name) {
+ Data dat = getVariable(name);
+ //error handling if non existing or no list
+ if (dat == null)
+ throw new DMLRuntimeException("Variable '" + name + "' does not exist in the symbol table.");
+ if (!(dat instanceof ListObject))
+ throw new DMLRuntimeException("Variable '" + name + "' is not a list.");
+ return (ListObject) dat;
+ }
+
public void releaseMatrixOutputForGPUInstruction(String varName) {
MatrixObject mo = getMatrixObject(varName);
if(mo.getGPUObject(getGPUContext(0)) == null || !mo.getGPUObject(getGPUContext(0)).isAllocated()) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
new file mode 100644
index 0000000..181b866
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public class LocalPSWorker extends PSWorker implements Runnable {
+
+ protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
+
+ public LocalPSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
+ ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
+ super(workerID, updFunc, freq, epochs, batchSize, hyperParams, ec, ps);
+ }
+
+ @Override
+ public void run() {
+
+ long dataSize = _features.getNumRows();
+
+ for (int i = 0; i < _epochs; i++) {
+ int totalIter = (int) Math.ceil((double) dataSize / _batchSize);
+ for (int j = 0; j < totalIter; j++) {
+ // Pull the global parameters from ps
+ // (deep copy, so that the local update does not mutate the shared model)
+ ListObject globalParams = ParamservUtils.copyList((ListObject) _ps.pull(_workerID));
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format(
+ "Local worker_%d: Successfully pull the global parameters [size:%d kb] from ps.", _workerID,
+ globalParams.getDataSize() / 1024));
+ }
+ _ec.setVariable(Statement.PS_MODEL, globalParams);
+
+ long begin = j * _batchSize + 1;
+ long end = Math.min(begin + _batchSize - 1, dataSize);
+
+ // Get batch features and labels
+ MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
+ MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
+ _ec.setVariable(Statement.PS_FEATURES, bFeatures);
+ _ec.setVariable(Statement.PS_LABELS, bLabels);
+
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format(
+ "Local worker_%d: Got batch data [size:%d kb] of index from %d to %d. [Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
+ _workerID, bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, i + 1,
+ _epochs, j + 1, totalIter));
+ }
+
+ // Invoke the update function
+ _inst.processInstruction(_ec);
+
+ // Get the gradients
+ ListObject gradients = (ListObject) _ec.getVariable(_outputs.get(0).getName());
+
+ // Push the gradients to ps
+ _ps.push(_workerID, gradients);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format("Local worker_%d: Successfully push the gradients [size:%d kb] to ps.",
+ _workerID, gradients.getDataSize() / 1024));
+ }
+
+ ParamservUtils.cleanupListObject(_ec, globalParams);
+ ParamservUtils.cleanupData(bFeatures);
+ ParamservUtils.cleanupData(bLabels);
+ }
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format("Local worker_%d: Finished %d epoch.", _workerID, i + 1));
+ }
+ }
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format("Local worker_%d: Job finished.", _workerID));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
new file mode 100644
index 0000000..d060a91
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public class LocalParamServer extends ParamServer {
+
+ public LocalParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
+ Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum,
+ ListObject hyperParams) {
+ super(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+ }
+
+ @Override
+ public void push(long workerID, ListObject gradients) {
+ synchronized (_lock) {
+ _queue.add(new Gradient(workerID, gradients));
+ _lock.notifyAll();
+ }
+ }
+
+ @Override
+ public Data pull(long workerID) {
+ synchronized (_lock) {
+ while (getPulledState((int) workerID)) {
+ try {
+ _lock.wait();
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException(
+ String.format("Local worker_%d: failed to pull the global parameters.", workerID), e);
+ }
+ }
+ setPulledState((int) workerID, true);
+ }
+ return getResult();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
new file mode 100644
index 0000000..9ace823
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+@SuppressWarnings("unused")
+public abstract class PSWorker {
+
+ long _workerID = -1;
+ int _epochs;
+ long _batchSize;
+ MatrixObject _features;
+ MatrixObject _labels;
+ ExecutionContext _ec;
+ ParamServer _ps;
+ private String _updFunc;
+ private Statement.PSFrequency _freq;
+ private MatrixObject _valFeatures;
+ private MatrixObject _valLabels;
+
+ ArrayList<DataIdentifier> _outputs;
+ FunctionCallCPInstruction _inst;
+
+ public PSWorker(long workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
+ ListObject hyperParams, ExecutionContext ec, ParamServer ps) {
+ this._workerID = workerID;
+ this._updFunc = updFunc;
+ this._freq = freq;
+ this._epochs = epochs;
+ this._batchSize = batchSize;
+ this._ec = ExecutionContextFactory.createContext(ec.getProgram());
+ if (hyperParams != null) {
+ this._ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
+ }
+ this._ps = ps;
+
+ // Get the update function
+ String[] keys = DMLProgram.splitFunctionKey(updFunc);
+ String funcName = keys[0];
+ String funcNS = null;
+ if (keys.length == 2) {
+ funcNS = keys[0];
+ funcName = keys[1];
+ }
+ FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+ ArrayList<DataIdentifier> inputs = func.getInputParams();
+ _outputs = func.getOutputParams();
+ CPOperand[] boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ ArrayList<String> outputNames = _outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames,
+ "update function");
+
+ // Check the inputs of the update function
+ checkInput(inputs, Expression.DataType.MATRIX, Statement.PS_FEATURES);
+ checkInput(inputs, Expression.DataType.MATRIX, Statement.PS_LABELS);
+ checkInput(inputs, Expression.DataType.LIST, Statement.PS_MODEL);
+ if (hyperParams != null) {
+ checkInput(inputs, Expression.DataType.LIST, Statement.PS_HYPER_PARAMS);
+ }
+
+ // Check the output of the update function
+ if (_outputs.size() != 1) {
+ throw new DMLRuntimeException(
+ String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc));
+ }
+ if (_outputs.get(0).getDataType() != Expression.DataType.LIST) {
+ throw new DMLRuntimeException(
+ String.format("The output of the '%s' function should be type of list.", updFunc));
+ }
+ }
+
+ private void checkInput(ArrayList<DataIdentifier> inputs, Expression.DataType dt, String pname) {
+ if (inputs.stream().filter(input -> input.getDataType() == dt && pname.equals(input.getName())).count() != 1) {
+ throw new DMLRuntimeException(
+ String.format("The '%s' function should provide an input of '%s' type named '%s'.", _updFunc, dt, pname));
+ }
+ }
+
+ public void setFeatures(MatrixObject features) {
+ this._features = features;
+ }
+
+ public void setLabels(MatrixObject labels) {
+ this._labels = labels;
+ }
+
+ public void setValFeatures(MatrixObject valFeatures) {
+ this._valFeatures = valFeatures;
+ }
+
+ public void setValLabels(MatrixObject valLabels) {
+ this._valLabels = valLabels;
+ }
+}
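PSWorker pins down the contract of the update function: matrix inputs named per Statement.PS_FEATURES and PS_LABELS, a list input named per PS_MODEL (plus PS_HYPER_PARAMS when hyperparameters are supplied), and exactly one list output carrying the gradients. A minimal DML sketch of a function satisfying these checks; the toy linear-model gradient and the positional list access are illustrative assumptions, not code from this commit:

gradients = function(matrix[double] features, matrix[double] labels,
    list[unknown] model, list[unknown] hyperparams)
  return (list[unknown] gradients) {
  W = as.matrix(model[1])  # assumption: model holds a single weight matrix
  # toy gradient of a linear least-squares model
  dW = t(features) %*% ((features %*% W) - labels)
  gradients = list(dW=dW)
}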
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
new file mode 100644
index 0000000..6e1cd13
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.DataIdentifier;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public abstract class ParamServer {
+
+ public class Gradient {
+ final long _workerID;
+ final ListObject _gradients;
+
+ public Gradient(long workerID, ListObject gradients) {
+ this._workerID = workerID;
+ this._gradients = gradients;
+ }
+ }
+
+ Queue<Gradient> _queue;
+ final Object _lock = new Object();
+ private ListObject _model;
+ private AggregationService _aggService;
+ private Thread _aggThread;
+ private boolean[] _pulledStates;
+
+ protected ParamServer(ListObject model, String aggFunc, Statement.PSFrequency freq,
+ Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum, ListObject hyperParams) {
+ this._queue = new ConcurrentLinkedQueue<>();
+ this._model = model;
+ this._aggService = new AggregationService(aggFunc, freq, updateType, ec, workerNum, hyperParams);
+ this._pulledStates = new boolean[workerNum];
+ this._aggThread = new Thread(_aggService);
+ }
+
+ public abstract void push(long workerID, ListObject value);
+
+ public abstract Data pull(long workerID);
+
+ public void start() {
+ _aggService._alive = true;
+ _aggThread.start();
+ }
+
+ public void stop() {
+ _aggService._alive = false;
+ try {
+ _aggThread.join();
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException("Parameter server: failed when stopping the server.", e);
+ }
+ }
+
+ public ListObject getResult() {
+ return _model;
+ }
+
+ public boolean getPulledState(int workerID) {
+ return _pulledStates[workerID];
+ }
+
+ public void setPulledState(int workerID, boolean state) {
+ _pulledStates[workerID] = state;
+ }
+
+ private void resetPulledStates() {
+ _pulledStates = new boolean[_pulledStates.length];
+ }
+
+ /**
+ * Inner aggregation service which is for updating the model
+ */
+ @SuppressWarnings("unused")
+ private class AggregationService implements Runnable {
+
+ protected final Log LOG = LogFactory.getLog(AggregationService.class.getName());
+
+ protected ExecutionContext _ec;
+ private Statement.PSFrequency _freq;
+ private Statement.PSUpdateType _updateType;
+ private FunctionCallCPInstruction _inst;
+ private DataIdentifier _output;
+ private boolean _alive;
+ private boolean[] _finishedStates; // Workers' finished states
+
+ AggregationService(String aggFunc, Statement.PSFrequency freq, Statement.PSUpdateType updateType,
+ ExecutionContext ec, int workerNum, ListObject hyperParams) {
+ _ec = ExecutionContextFactory.createContext(ec.getProgram());
+ _freq = freq;
+ _updateType = updateType;
+ if (hyperParams != null) {
+ _ec.setVariable(Statement.PS_HYPER_PARAMS, hyperParams);
+ }
+ _finishedStates = new boolean[workerNum];
+
+ // Fetch the aggregation function
+ String[] keys = DMLProgram.splitFunctionKey(aggFunc);
+ String funcName = keys[0];
+ String funcNS = null;
+ if (keys.length == 2) {
+ funcNS = keys[0];
+ funcName = keys[1];
+ }
+ FunctionProgramBlock func = _ec.getProgram().getFunctionProgramBlock(funcNS, funcName);
+ ArrayList<DataIdentifier> inputs = func.getInputParams();
+ ArrayList<DataIdentifier> outputs = func.getOutputParams();
+
+ // Check the output of the aggregation function
+ if (outputs.size() != 1) {
+ throw new DMLRuntimeException(String.format(
+ "The output of the '%s' function should provide one list containing the updated model.",
+ aggFunc));
+ }
+ if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
+ throw new DMLRuntimeException(
+ String.format("The output of the '%s' function should be type of list.", aggFunc));
+ }
+ _output = outputs.get(0);
+
+ CPOperand[] boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ ArrayList<String> inputNames = inputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ _inst = new FunctionCallCPInstruction(funcNS, funcName, boundInputs, inputNames, outputNames,
+ "aggregate function");
+ }
+
+ boolean isAlive() {
+ return _alive;
+ }
+
+ private boolean allFinished() {
+ return !ArrayUtils.contains(_finishedStates, false);
+ }
+
+ private void resetFinishedStates() {
+ Arrays.fill(_finishedStates, false);
+ }
+
+ private void setFinishedState(int workerID) {
+ _finishedStates[workerID] = true;
+ }
+
+ @Override
+ public void run() {
+ synchronized (_lock) {
+ while (isAlive()) {
+ do {
+ while (_queue.isEmpty()) {
+ try {
+ _lock.wait();
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException(
+ "Aggregation service: error when waiting for the coming gradients.", e);
+ }
+ }
+ Gradient p = _queue.remove();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.",
+ p._gradients.getDataSize() / 1024, p._workerID));
+ }
+
+ setFinishedState((int) p._workerID);
+
+ // Populate the variables table with the gradients and model
+ _ec.setVariable(Statement.PS_GRADIENTS, p._gradients);
+ _ec.setVariable(Statement.PS_MODEL, _model);
+
+ // Invoke the aggregate function
+ _inst.processInstruction(_ec);
+
+ // Get the output
+ ListObject newModel = (ListObject) _ec.getVariable(_output.getName());
+
+ // Update the model with the new output
+ ParamservUtils.cleanupListObject(_ec, _model);
+ ParamservUtils.cleanupListObject(_ec, p._gradients);
+ _model = newModel;
+
+ } while (!allFinished());
+
+ // notify all the workers to get the updated model
+ resetPulledStates();
+ resetFinishedStates();
+ _lock.notifyAll();
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Global parameter is broadcasted successfully.");
+ }
+ }
+ }
+ }
+ }
+}
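The aggregation service binds the dequeued gradients to the variable named by the new Statement.PS_GRADIENTS constant ("gradients") and the current model to PS_MODEL, and expects exactly one list output, the updated model. A matching DML aggregation function might look as follows; the vanilla SGD step and the list access are assumptions for illustration:

aggregation = function(list[unknown] model, list[unknown] gradients,
    list[unknown] hyperparams)
  return (list[unknown] model_result) {
  W = as.matrix(model[1])
  dW = as.matrix(gradients[1])
  lr = as.scalar(hyperparams[1])  # assumption: learning rate as first hyperparam
  model_result = list(W = W - lr * dW)
}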
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
new file mode 100644
index 0000000..54c5d6c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MetaDataFormat;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+
+public class ParamservUtils {
+
+ /**
+ * Deep copy the list object
+ *
+ * @param lo list object
+ * @return a new copied list object
+ */
+ public static ListObject copyList(ListObject lo) {
+ if (lo.getLength() == 0) {
+ return lo;
+ }
+ List<Data> newData = lo.getNames().stream().map(name -> {
+ Data oldData = lo.slice(name);
+ if (oldData instanceof MatrixObject) {
+ MatrixObject mo = (MatrixObject) oldData;
+ return sliceMatrix(mo, 1, mo.getNumRows());
+ } else if (oldData instanceof ListObject || oldData instanceof FrameObject) {
+ throw new DMLRuntimeException("Copy list: does not support list or frame.");
+ } else {
+ return oldData;
+ }
+ }).collect(Collectors.toList());
+ return new ListObject(newData, lo.getNames());
+ }
+
+ public static void cleanupListObject(ExecutionContext ec, ListObject lo) {
+ ec.getVariables().removeAllIn(new HashSet<>(lo.getNames()));
+ lo.getData().forEach(ParamservUtils::cleanupData);
+ }
+
+ public static void cleanupData(Data data) {
+ if( !(data instanceof CacheableData) )
+ return;
+ CacheableData<?> cd = (CacheableData<?>) data;
+ cd.enableCleanup(true);
+ cd.clearData();
+ }
+
+ /**
+ * Slice the matrix, returning rows rl through rh
+ * @param mo input matrix
+ * @param rl row lower bound (1-based, inclusive)
+ * @param rh row upper bound (1-based, inclusive)
+ * @return new sliced matrix
+ */
+ public static MatrixObject sliceMatrix(MatrixObject mo, long rl, long rh) {
+ MatrixObject result = new MatrixObject(Expression.ValueType.DOUBLE, null,
+ new MetaDataFormat(new MatrixCharacteristics(-1, -1, -1, -1),
+ OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+ MatrixBlock tmp = mo.acquireRead();
+ result.acquireModify(tmp.slice((int)rl-1, (int)rh-1, 0,
+ tmp.getNumColumns()-1, new MatrixBlock()));
+ mo.release();
+ result.release();
+ return result;
+ }
+}
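Note that sliceMatrix takes 1-based, inclusive row bounds (the shift to 0-based indices happens internally), so it mirrors DML right indexing:

# sliceMatrix(mo, 1, 64) on the Java side corresponds to this DML expression:
batch = X[1:64,]  # rows 1 through 64, inclusive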
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
index 1ca8eab..22b79b0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPOperand.java
@@ -46,7 +46,7 @@ public class CPOperand
this(name, vt, dt, false);
}
- private CPOperand(String name, ValueType vt, DataType dt, boolean literal) {
+ public CPOperand(String name, ValueType vt, DataType dt, boolean literal) {
_name = name;
_valueType = vt;
_dataType = dt;
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
index 95f03b5..670190c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ListObject.java
@@ -25,6 +25,7 @@ import java.util.List;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
public class ListObject extends Data {
private static final long serialVersionUID = 3652422061598967358L;
@@ -107,6 +108,19 @@ public class ListObject extends Data {
return (_names == null) ? null : _names.get(ix);
}
+ public boolean isNamedList() {
+ return _names != null;
+ }
+
+ public List<Data> getData() {
+ return _data;
+ }
+
+ public long getDataSize() {
+ return _data.stream().filter(data -> data instanceof CacheableData)
+ .mapToLong(data -> ((CacheableData<?>) data).getDataSize()).sum();
+ }
+
@Override
public String getDebugName() {
return toString();
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index 51cc4c1..4e5d4c0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -34,8 +34,8 @@ import org.apache.sysml.utils.Statistics;
public final class MatrixIndexingCPInstruction extends IndexingCPInstruction {
- protected MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl,
- CPOperand cu, CPOperand out, String opcode, String istr) {
+ public MatrixIndexingCPInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu,
+ CPOperand out, String opcode, String istr) {
super(in, rl, ru, cl, cu, out, opcode, istr);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index ddc56ae..3ab0fc8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,14 +19,62 @@
package org.apache.sysml.runtime.instructions.cp;
+import static org.apache.sysml.parser.Statement.PSFrequency;
+import static org.apache.sysml.parser.Statement.PSModeType;
+import static org.apache.sysml.parser.Statement.PSScheme;
+import static org.apache.sysml.parser.Statement.PSUpdateType;
+import static org.apache.sysml.parser.Statement.PS_AGGREGATION_FUN;
+import static org.apache.sysml.parser.Statement.PS_BATCH_SIZE;
+import static org.apache.sysml.parser.Statement.PS_EPOCHS;
+import static org.apache.sysml.parser.Statement.PS_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_FREQUENCY;
+import static org.apache.sysml.parser.Statement.PS_HYPER_PARAMS;
+import static org.apache.sysml.parser.Statement.PS_LABELS;
+import static org.apache.sysml.parser.Statement.PS_MODE;
+import static org.apache.sysml.parser.Statement.PS_MODEL;
+import static org.apache.sysml.parser.Statement.PS_PARALLELISM;
+import static org.apache.sysml.parser.Statement.PS_SCHEME;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_FUN;
+import static org.apache.sysml.parser.Statement.PS_UPDATE_TYPE;
+import static org.apache.sysml.parser.Statement.PS_VAL_FEATURES;
+import static org.apache.sysml.parser.Statement.PS_VAL_LABELS;
+
+import java.util.ArrayList;
import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
-import org.apache.sysml.parser.Statement;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.utils.NativeHelper;
public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruction {
+ private static final int DEFAULT_BATCH_SIZE = 64;
+ private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = PSFrequency.BATCH;
+ private static final int DEFAULT_LEVEL_PARALLELISM = InfrastructureAnalyzer.getLocalParallelism();
+ private static final PSScheme DEFAULT_SCHEME = PSScheme.DISJOINT_CONTIGUOUS;
+
+ //internal local debug level
+ private static final boolean LDEBUG = false;
+
+ static {
+ // for internal debugging only
+ if (LDEBUG) {
+ Logger.getLogger("org.apache.sysml.runtime.controlprogram.paramserv").setLevel((Level) Level.DEBUG);
+ }
+ }
+
protected ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out,
String opcode, String istr) {
super(op, paramsMap, out, opcode, istr);
@@ -34,8 +82,209 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
@Override
public void processInstruction(ExecutionContext ec) {
- ListObject model = (ListObject) ec.getVariable(getParam(Statement.PS_MODEL));
- ListObject outList = model.slice(0, model.getLength() - 1);
- ec.setVariable(output.getName(), outList);
+
+ PSModeType mode = PSModeType.valueOf(getParam(PS_MODE));
+ int workerNum = getWorkerNum(mode);
+ String updFunc = getParam(PS_UPDATE_FUN);
+ String aggFunc = getParam(PS_AGGREGATION_FUN);
+ PSFrequency freq = getFrequency();
+ PSUpdateType updateType = getUpdateType();
+ int epochs = Integer.valueOf(getParam(PS_EPOCHS));
+ if (epochs <= 0) {
+ throw new DMLRuntimeException(
+ String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.",
+ PS_EPOCHS));
+ }
+ long batchSize = getBatchSize();
+
+ // Create the parameter server
+ ListObject model = ec.getListObject(getParam(PS_MODEL));
+ ListObject hyperParams = getHyperParams(ec);
+ ParamServer ps = createPS(mode, aggFunc, freq, updateType, workerNum, model, ec, hyperParams);
+
+ // Create the local workers
+ List<LocalPSWorker> workers = IntStream.range(0, workerNum)
+ .mapToObj(i -> new LocalPSWorker((long) i, updFunc, freq, epochs, batchSize, hyperParams, ec, ps))
+ .collect(Collectors.toList());
+
+ // Do data partition
+ doDataPartition(ec, workers);
+
+ // Create the worker threads
+ List<Thread> threads = workers.stream().map(Thread::new).collect(Collectors.toList());
+
+ // Start the ps
+ ps.start();
+
+ // Start the workers
+ threads.forEach(Thread::start);
+
+ // Wait for the workers to stop
+ threads.forEach(thread -> {
+ try {
+ thread.join();
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException("Paramserv function: Failed to join the worker threads.", e);
+ }
+ });
+
+ ps.stop();
+
+ // Create the output
+ ListObject result = ps.getResult();
+ ec.setVariable(output.getName(), result);
+ }
+
+ private PSUpdateType getUpdateType() {
+ PSUpdateType updType = PSUpdateType.valueOf(getParam(PS_UPDATE_TYPE));
+ switch (updType) {
+ case ASP:
+ case SSP:
+ throw new DMLRuntimeException(String.format("Not support update type '%s'.", updType));
+ case BSP:
+ break;
+ }
+ return updType;
+ }
+
+ private PSFrequency getFrequency() {
+ if (!getParameterMap().containsKey(PS_FREQUENCY)) {
+ return DEFAULT_UPDATE_FREQUENCY;
+ }
+ PSFrequency freq = PSFrequency.valueOf(getParam(PS_FREQUENCY));
+ switch (freq) {
+ case EPOCH:
+ throw new DMLRuntimeException("Not support epoch update frequency.");
+ case BATCH:
+ break;
+ }
+ return freq;
+ }
+
+ /**
+ * Get the number of workers according to the vcores
+ *
+ * @param mode execution mode
+ * @return number of workers
+ */
+ private int getWorkerNum(PSModeType mode) {
+ int workerNum = DEFAULT_LEVEL_PARALLELISM;
+ if (getParameterMap().containsKey(PS_PARALLELISM)) {
+ workerNum = Integer.valueOf(getParam(PS_PARALLELISM));
+ }
+ switch (mode) {
+ case LOCAL:
+ //FIXME: this is a workaround for a maximum number of buffers in openblas
+ //However, the root cause is a missing function preparation for each worker
+ //(i.e., deep copy with unique file names, and reduced degree of parallelism)
+ int vcores = InfrastructureAnalyzer.getLocalParallelism();
+ if ("openblas".equals(NativeHelper.getCurrentBLAS())) {
+ workerNum = Math.min(workerNum, vcores / 2);
+ } else {
+ workerNum = Math.min(workerNum, vcores);
+ }
+ break;
+ case REMOTE_SPARK:
+ throw new DMLRuntimeException("Do not support remote spark.");
+ }
+ return workerNum;
+ }
+
+ /**
+ * Create a server which serves the local or remote workers
+ *
+ * @return parameter server
+ */
+ private ParamServer createPS(PSModeType mode, String aggFunc, PSFrequency freq, PSUpdateType updateType,
+ int workerNum, ListObject model, ExecutionContext ec, ListObject hyperParams) {
+ ParamServer ps = null;
+ switch (mode) {
+ case LOCAL:
+ ps = new LocalParamServer(model, aggFunc, freq, updateType, ec, workerNum, hyperParams);
+ break;
+ case REMOTE_SPARK:
+ throw new DMLRuntimeException("Do not support remote spark.");
+ }
+ return ps;
+ }
+
+ private long getBatchSize() {
+ if (!getParameterMap().containsKey(PS_BATCH_SIZE)) {
+ return DEFAULT_BATCH_SIZE;
+ }
+ long batchSize = Integer.valueOf(getParam(PS_BATCH_SIZE));
+ if (batchSize <= 0) {
+ throw new DMLRuntimeException(String.format(
+ "Paramserv function: the number of argument '%s' could not be less than or equal to 0.",
+ PS_BATCH_SIZE));
+ }
+ return batchSize;
+ }
+
+ private ListObject getHyperParams(ExecutionContext ec) {
+ ListObject hyperparams = null;
+ if (getParameterMap().containsKey(PS_HYPER_PARAMS)) {
+ hyperparams = ec.getListObject(getParam(PS_HYPER_PARAMS));
+ }
+ return hyperparams;
+ }
+
+ private void doDataPartition(ExecutionContext ec, List<LocalPSWorker> workers) {
+ MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
+ MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
+ MatrixObject valFeatures = ec.getMatrixObject(getParam(PS_VAL_FEATURES));
+ MatrixObject valLabels = ec.getMatrixObject(getParam(PS_VAL_LABELS));
+ PSScheme scheme = DEFAULT_SCHEME;
+ if (getParameterMap().containsKey(PS_SCHEME)) {
+ scheme = PSScheme.valueOf(getParam(PS_SCHEME));
+ }
+ switch (scheme) {
+ case DISJOINT_CONTIGUOUS:
+ disjointContiguous(features, labels, valFeatures, valLabels, workers);
+ break;
+ case DISJOINT_RANDOM:
+ case OVERLAP_RESHUFFLE:
+ case DISJOINT_ROUND_ROBIN:
+ throw new DMLRuntimeException(
+ String.format("Paramserv function: the scheme '%s' is not supported.", scheme));
+ }
+ }
+
+ private void disjointContiguous(MatrixObject features, MatrixObject labels, MatrixObject valFeatures,
+ MatrixObject valLabels, List<LocalPSWorker> workers) {
+ // training data
+ List<MatrixObject> pfs = disjointContiguous(workers.size(), features);
+ List<MatrixObject> pls = disjointContiguous(workers.size(), labels);
+ if (pfs.size() < workers.size()) {
+ LOG.warn(String.format(
+ "There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.",
+ pfs.size(), workers.size(), pfs.size()));
+ workers = workers.subList(0, pfs.size());
+ }
+ for (int i = 0; i < workers.size(); i++) {
+ workers.get(i).setFeatures(pfs.get(i));
+ workers.get(i).setLabels(pls.get(i));
+ }
+
+ // validation data
+ List<MatrixObject> pvfs = disjointContiguous(workers.size(), valFeatures);
+ List<MatrixObject> pvls = disjointContiguous(workers.size(), valLabels);
+ for (int i = 0; i < workers.size(); i++) {
+ workers.get(i).setValFeatures(pvfs.get(i));
+ workers.get(i).setValLabels(pvls.get(i));
+ }
+ }
+
+ private List<MatrixObject> disjointContiguous(int workerNum, MatrixObject mo) {
+ List<MatrixObject> list = new ArrayList<>();
+ long stepSize = (long) Math.ceil((double) mo.getNumRows() / workerNum);
+ long begin = 1;
+ while (begin <= mo.getNumRows()) {
+ long end = Math.min(begin + stepSize - 1, mo.getNumRows());
+ MatrixObject pmo = ParamservUtils.sliceMatrix(mo, begin, end);
+ list.add(pmo);
+ begin = end + 1;
+ }
+ return list;
}
}
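For intuition, DISJOINT_CONTIGUOUS hands each worker one contiguous block of roughly ceil(n/k) rows. A small worked DML sketch of the resulting row ranges, with hypothetical sizes:

# 1000 rows across 3 workers yields the ranges [1,334], [335,668], [669,1000]
n = 1000
k = 3
step = ceil(n / k)
for (i in 1:k) {
  begin = (i - 1) * step + 1
  end = min(i * step, n)
  print("worker_" + i + ": rows " + begin + " to " + end)
}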
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 47ea66e..43f5229 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1250,9 +1250,11 @@ public abstract class AutomatedTestBase
if (exceptionExpected)
fail("expected exception which has not been raised: " + expectedException);
} catch (Exception e) {
- if (exceptionExpected && e.getClass().equals(expectedException) && errMessage != null
- && !e.getMessage().contains(errMessage)) {
- fail("expected exception message has not been raised: " + errMessage);
+ if (errMessage != null && !errMessage.equals("")) {
+ boolean result = rCompareException(exceptionExpected, errMessage, e, false);
+ if (exceptionExpected && !result) {
+ fail(String.format("expected exception message '%s' has not been raised.", errMessage));
+ }
}
if (!exceptionExpected || (expectedException != null && !(e.getClass().equals(expectedException)))) {
e.printStackTrace();
@@ -1269,6 +1271,16 @@ public abstract class AutomatedTestBase
}
}
+ private boolean rCompareException(boolean exceptionExpected, String errMessage, Throwable e, boolean result) {
+ if (e.getCause() != null) {
+ result |= rCompareException(exceptionExpected, errMessage, e.getCause(), result);
+ }
+ if (exceptionExpected && errMessage != null && e.getMessage() != null && e.getMessage().contains(errMessage)) {
+ result = true;
+ }
+ return result;
+ }
+
public void cleanupScratchSpace()
{
try
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
index 1b227f1..6370099 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/ParamservFuncTest.java
@@ -32,7 +32,8 @@ public class ParamservFuncTest extends AutomatedTestBase {
private static final String TEST_NAME4 = "paramserv-wrong-type-args";
private static final String TEST_NAME5 = "paramserv-wrong-args";
private static final String TEST_NAME6 = "paramserv-wrong-args2";
- private static final String TEST_NAME7 = "paramserv-ipa-test";
+ private static final String TEST_NAME7 = "paramserv-nn-test";
+ private static final String TEST_NAME8 = "paramserv-minimum-version";
private static final String TEST_DIR = "functions/paramserv/";
private static final String TEST_CLASS_DIR = TEST_DIR + ParamservFuncTest.class.getSimpleName() + "/";
@@ -48,53 +49,59 @@ public class ParamservFuncTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {}));
addTestConfiguration(TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] {}));
addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {}));
+ addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {}));
}
@Test
public void testParamservWithAllArgs() {
- runDMLTest(TEST_NAME1, true, false, null, null);
+ runDMLTest(TEST_NAME1, false, null, null);
}
@Test
public void testParamservWithoutOptionalArgs() {
- runDMLTest(TEST_NAME2, true, false, null, null);
+ runDMLTest(TEST_NAME2, false, null, null);
}
@Test
public void testParamservMissArgs() {
final String errmsg = "Named parameter 'features' missing. Please specify the input.";
- runDMLTest(TEST_NAME3, true, true, DMLException.class, errmsg);
+ runDMLTest(TEST_NAME3, true, DMLException.class, errmsg);
}
@Test
public void testParamservWrongTypeArgs() {
final String errmsg = "Input to PARAMSERV::model must be of type 'LIST'. It should not be of type 'MATRIX'";
- runDMLTest(TEST_NAME4, true, true, DMLException.class, errmsg);
+ runDMLTest(TEST_NAME4, true, DMLException.class, errmsg);
}
@Test
public void testParamservWrongArgs() {
final String errmsg = "Function PARAMSERV does not support value 'NSP' as the 'utype' parameter.";
- runDMLTest(TEST_NAME5, true, true, DMLException.class, errmsg);
+ runDMLTest(TEST_NAME5, true, DMLException.class, errmsg);
}
@Test
public void testParamservWrongArgs2() {
final String errmsg = "Invalid parameters for PARAMSERV: [modelList, val_featur=X_val]";
- runDMLTest(TEST_NAME6, true, true, DMLException.class, errmsg);
+ runDMLTest(TEST_NAME6, true, DMLException.class, errmsg);
}
@Test
- public void testParamservIpaTest() {
- runDMLTest(TEST_NAME7, true, false, null, "1");
+ public void testParamservNNTest() {
+ runDMLTest(TEST_NAME7, false, null, null);
}
- private void runDMLTest(String testname, boolean newWay, boolean exceptionExpected, Class<?> exceptionClass,
+ @Test
+ public void testParamservMinimumVersionTest() {
+ runDMLTest(TEST_NAME8, false, null, null);
+ }
+
+ private void runDMLTest(String testname, boolean exceptionExpected, Class<?> exceptionClass,
String errmsg) {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);
programArgs = new String[] { "-explain" };
fullDMLScriptName = HOME + testname + ".dml";
- runTest(newWay, exceptionExpected, exceptionClass, errmsg, -1);
+ runTest(true, exceptionExpected, exceptionClass, errmsg, -1);
}
}
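
With the simplified runDMLTest signature (the legacy newWay flag is now fixed to true inside the helper), a new expected-failure test would follow the same four-argument pattern as the cases above. A hypothetical sketch, assuming it lives inside ParamservFuncTest and noting it is not part of this commit:

    @Test
    public void testParamservWrongScheme() {
      // hypothetical negative test: script name and message are placeholders
      final String errmsg = "expected message fragment";
      runDMLTest("paramserv-wrong-scheme", true, DMLException.class, errmsg);
    }
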
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
new file mode 100644
index 0000000..2a3bbe2
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -0,0 +1,383 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+ matrix[double] X_val, matrix[double] Y_val,
+ int C, int Hin, int Win, int epochs, int workers)
+ return (matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4) {
+ /*
+ * Trains a convolutional net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - Y: Target matrix, of shape (N, K).
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+ * - Y_val: Target validation matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - epochs: Total number of full training passes over the data set.
+ * - workers: Number of parameter server workers (passed as 'k' to paramserv).
+ *
+ * Outputs:
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ */
+ N = nrow(X)
+ K = ncol(Y)
+
+ # Create network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # scale down the initialization, since this layer feeds into a softmax instead of a relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ lr = 0.01 # learning rate
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+ # Regularization
+ lambda = 5e-04
+
+ # Create the model object
+ modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+ # Create the hyper parameter list
+ params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+ # Use paramserv function
+ modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode="LOCAL", utype="BSP", freq="BATCH", epochs=epochs, batchsize=64, k=workers, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, checkpointing="NONE")
+
+ W1 = as.matrix(modelList2["W1"])
+ b1 = as.matrix(modelList2["b1"])
+ W2 = as.matrix(modelList2["W2"])
+ b2 = as.matrix(modelList2["b2"])
+ W3 = as.matrix(modelList2["W3"])
+ b3 = as.matrix(modelList2["b3"])
+ W4 = as.matrix(modelList2["W4"])
+ b4 = as.matrix(modelList2["b4"])
+
+}
+
+# The gradient function must always take the named arguments 'features'
+# (batch features), 'labels' (batch labels), 'hyperparams', and 'model',
+# and must return the gradients as a list.
+gradients = function(matrix[double] features,
+ matrix[double] labels,
+ list[unknown] hyperparams,
+ list[unknown] model)
+ return (list[unknown] gradients) {
+
+# Known limitation: scalars cannot yet be retrieved from a list, so the hyperparameters are hard-coded below.
+
+ C = 1
+ Hin = 28
+ Win = 28
+ Hf = 5
+ Wf = 5
+ stride = 1
+ pad = 2
+ lambda = 5e-04
+ F1 = 32
+ F2 = 64
+ N3 = 512
+ W1 = as.matrix(model["W1"])
+ b1 = as.matrix(model["b1"])
+ W2 = as.matrix(model["W2"])
+ b2 = as.matrix(model["b2"])
+ W3 = as.matrix(model["W3"])
+ b3 = as.matrix(model["b3"])
+ W4 = as.matrix(model["W4"])
+ b4 = as.matrix(model["b4"])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute data backward pass
+ ## loss:
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+ Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+}
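
The regularization backward pass above adds the gradient of the L2 penalty to each weight gradient. Assuming the standard penalty (lambda/2) * sum(W^2) (an assumption about nn/layers/l2_reg.dml, not copied from it), that gradient is simply lambda * W; a minimal Java sketch:

    // Gradient of the L2 penalty (lambda/2) * sum(W .^ 2) w.r.t. W is
    // lambda * W, which is what gets added to each dW above.
    static double[] l2RegBackward(double[] w, double lambda) {
      double[] dReg = new double[w.length];
      for (int i = 0; i < w.length; i++)
        dReg[i] = lambda * w[i];
      return dReg;
    }
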
+
+# Open question: how to handle the velocity? (currently kept inside the model list)
+# The aggregation function must always take the named arguments 'model',
+# 'gradients', and 'hyperparams', and must return the updated model as a list.
+aggregation = function(list[unknown] model,
+ list[unknown] gradients,
+ list[unknown] hyperparams)
+ return (list[unknown] modelResult) {
+
+ W1 = as.matrix(model["W1"])
+ W2 = as.matrix(model["W2"])
+ W3 = as.matrix(model["W3"])
+ W4 = as.matrix(model["W4"])
+ b1 = as.matrix(model["b1"])
+ b2 = as.matrix(model["b2"])
+ b3 = as.matrix(model["b3"])
+ b4 = as.matrix(model["b4"])
+ dW1 = as.matrix(gradients["dW1"])
+ dW2 = as.matrix(gradients["dW2"])
+ dW3 = as.matrix(gradients["dW3"])
+ dW4 = as.matrix(gradients["dW4"])
+ db1 = as.matrix(gradients["db1"])
+ db2 = as.matrix(gradients["db2"])
+ db3 = as.matrix(gradients["db3"])
+ db4 = as.matrix(gradients["db4"])
+ vW1 = as.matrix(model["vW1"])
+ vW2 = as.matrix(model["vW2"])
+ vW3 = as.matrix(model["vW3"])
+ vW4 = as.matrix(model["vW4"])
+ vb1 = as.matrix(model["vb1"])
+ vb2 = as.matrix(model["vb2"])
+ vb3 = as.matrix(model["vb3"])
+ vb4 = as.matrix(model["vb4"])
+ lr = 0.01
+ mu = 0.9
+
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+ modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ }
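
The aggregation function applies SGD with Nesterov momentum to every parameter/velocity pair. Assuming the usual cs231n-style formulation of the update (v' = mu*v - lr*dW, then W' = W - mu*v + (1+mu)*v'), it can be sketched on flat arrays as follows:

    // SGD with Nesterov momentum on a flat parameter array (in-place),
    // assuming the cs231n-style formulation of the update rule.
    static void sgdNesterovUpdate(double[] w, double[] dw, double[] v,
                                  double lr, double mu) {
      for (int i = 0; i < w.length; i++) {
        double vPrev = v[i];
        v[i] = mu * v[i] - lr * dw[i];           // velocity update
        w[i] += -mu * vPrev + (1 + mu) * v[i];   // position update with look-ahead
      }
    }
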
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+ matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4)
+ return (matrix[double] probs) {
+ /*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K).
+ */
+ N = nrow(X)
+
+ # Network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = nrow(W1) # num conv filters in conv1
+ F2 = nrow(W2) # num conv filters in conv2
+ N3 = ncol(W3) # num nodes in affine3
+ K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes)
+
+ # Compute predictions over mini-batches
+ probs = matrix(0, rows=N, cols=K)
+ batch_size = 64
+ iters = ceil(N / batch_size)
+ for(i in 1:iters) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+ pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_batch = softmax::forward(outa4)
+
+ # Store predictions
+ probs[beg:end,] = probs_batch
+ }
+}
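
The prediction loop walks the data in fixed-size mini-batches using 1-based begin/end row indices. The same arithmetic as a small Java sanity check (a hypothetical helper, not part of the script):

    // Mirrors the DML index arithmetic:
    //   beg = ((i-1) * batch_size) %% N + 1;  end = min(N, beg + batch_size - 1)
    static int[] batchRange(int i, int batchSize, int n) {
      int beg = ((i - 1) * batchSize) % n + 1;     // 1-based begin row
      int end = Math.min(n, beg + batchSize - 1);  // clamp the final batch
      return new int[] { beg, end };
    }
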
+
+eval = function(matrix[double] probs, matrix[double] Y)
+ return (double loss, double accuracy) {
+ /*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, Y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K).
+ * - Y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, Y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+ accuracy = mean(correct_pred)
+}
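
Accuracy here is the fraction of rows whose predicted argmax matches the argmax of the one-hot target, i.e. mean(rowIndexMax(probs) == rowIndexMax(Y)). An equivalent Java sketch (hypothetical helpers):

    // Fraction of rows where the argmax of probs matches the argmax of Y.
    static double accuracy(double[][] probs, double[][] y) {
      int correct = 0;
      for (int i = 0; i < probs.length; i++)
        if (argmax(probs[i]) == argmax(y[i]))
          correct++;
      return (double) correct / probs.length;
    }

    static int argmax(double[] row) {
      int best = 0;
      for (int j = 1; j < row.length; j++)
        if (row[j] > row[best])
          best = j;
      return best;
    }
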
+
+generate_dummy_data = function()
+ return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+ /*
+ * Generate a dummy dataset similar to the MNIST dataset.
+ *
+ * Outputs:
+ * - X: Input data matrix, of shape (N, D).
+ * - Y: Target matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ */
+ # Generate dummy input data
+ N = 1024 # num examples
+ C = 1 # num input channels
+ Hin = 28 # input height
+ Win = 28 # input width
+ K = 10 # num target classes
+ X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+ classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+ Y = table(seq(1, N), classes) # one-hot encoding
+}
+
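
In generate_dummy_data, table(seq(1, N), classes) builds a contingency table that acts as a one-hot encoding: row i receives a 1 in the column given by classes[i]. An equivalent Java sketch for 1-based class labels (hypothetical helper):

    // One-hot encode 1-based class labels into an N x K matrix,
    // mirroring table(seq(1, N), classes) in DML.
    static double[][] oneHot(int[] classes, int k) {
      double[][] y = new double[classes.length][k];
      for (int i = 0; i < classes.length; i++)
        y[i][classes[i] - 1] = 1.0; // shift to a 0-based column index
      return y;
    }
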
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
new file mode 100644
index 0000000..2ef7411
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -0,0 +1,377 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+ matrix[double] X_val, matrix[double] Y_val,
+ int C, int Hin, int Win, int epochs, int workers)
+ return (matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4) {
+ /*
+ * Trains a convolutional net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - Y: Target matrix, of shape (N, K).
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+ * - Y_val: Target validation matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - epochs: Total number of full training passes over the data set.
+ * - workers: Number of parameter server workers (unused by the minimal paramserv call below).
+ *
+ * Outputs:
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ */
+ N = nrow(X)
+ K = ncol(Y)
+
+ # Create network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # scale down the initialization, since this layer feeds into a softmax instead of a relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ lr = 0.01 # learning rate
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+ # Regularization
+ lambda = 5e-04
+
+ # Create the model object
+ modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+
+ # Create the hyper parameter list
+ params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+ # Use paramserv function
+ modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
+
+ W1 = as.matrix(modelList2["W1"])
+ b1 = as.matrix(modelList2["b1"])
+ W2 = as.matrix(modelList2["W2"])
+ b2 = as.matrix(modelList2["b2"])
+ W3 = as.matrix(modelList2["W3"])
+ b3 = as.matrix(modelList2["b3"])
+ W4 = as.matrix(modelList2["W4"])
+ b4 = as.matrix(modelList2["b4"])
+
+}
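
Note that, unlike the call in mnist_lenet_paramserv.dml above, this paramserv invocation passes only the required arguments and omits the optional ones (freq, batchsize, k, scheme, and checkpointing), which is exactly the minimal signature that the paramserv-minimum-version test exercises.
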
+
+gradients = function(matrix[double] features,
+ matrix[double] labels,
+ list[unknown] hyperparams,
+ list[unknown] model)
+ return (list[unknown] gradients) {
+
+ C = 1
+ Hin = 28
+ Win = 28
+ Hf = 5
+ Wf = 5
+ stride = 1
+ pad = 2
+ lambda = 5e-04
+ F1 = 32
+ F2 = 64
+ N3 = 512
+ W1 = as.matrix(model["W1"])
+ b1 = as.matrix(model["b1"])
+ W2 = as.matrix(model["W2"])
+ b2 = as.matrix(model["b2"])
+ W3 = as.matrix(model["W3"])
+ b3 = as.matrix(model["b3"])
+ W4 = as.matrix(model["W4"])
+ b4 = as.matrix(model["b4"])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute data backward pass
+ ## loss:
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+ Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+
+}
+
+# Open question: how to handle the velocity? (currently kept inside the model list)
+aggregation = function(list[unknown] model,
+ list[unknown] gradients,
+ list[unknown] hyperparams)
+ return (list[unknown] modelResult) {
+
+ W1 = as.matrix(model["W1"])
+ W2 = as.matrix(model["W2"])
+ W3 = as.matrix(model["W3"])
+ W4 = as.matrix(model["W4"])
+ b1 = as.matrix(model["b1"])
+ b2 = as.matrix(model["b2"])
+ b3 = as.matrix(model["b3"])
+ b4 = as.matrix(model["b4"])
+ dW1 = as.matrix(gradients["dW1"])
+ dW2 = as.matrix(gradients["dW2"])
+ dW3 = as.matrix(gradients["dW3"])
+ dW4 = as.matrix(gradients["dW4"])
+ db1 = as.matrix(gradients["db1"])
+ db2 = as.matrix(gradients["db2"])
+ db3 = as.matrix(gradients["db3"])
+ db4 = as.matrix(gradients["db4"])
+ vW1 = as.matrix(model["vW1"])
+ vW2 = as.matrix(model["vW2"])
+ vW3 = as.matrix(model["vW3"])
+ vW4 = as.matrix(model["vW4"])
+ vb1 = as.matrix(model["vb1"])
+ vb2 = as.matrix(model["vb2"])
+ vb3 = as.matrix(model["vb3"])
+ vb4 = as.matrix(model["vb4"])
+ lr = 0.01
+ mu = 0.9
+
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+ modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ }
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+ matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4)
+ return (matrix[double] probs) {
+ /*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K).
+ */
+ N = nrow(X)
+
+ # Network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = nrow(W1) # num conv filters in conv1
+ F2 = nrow(W2) # num conv filters in conv2
+ N3 = ncol(W3) # num nodes in affine3
+ K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes)
+
+ # Compute predictions over mini-batches
+ probs = matrix(0, rows=N, cols=K)
+ batch_size = 64
+ iters = ceil(N / batch_size)
+ for(i in 1:iters) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+ pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+ strideh=2, stridew=2, pad=0, pad=0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_batch = softmax::forward(outa4)
+
+ # Store predictions
+ probs[beg:end,] = probs_batch
+ }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+ return (double loss, double accuracy) {
+ /*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, Y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K).
+ * - Y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, Y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+ accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+ return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+ /*
+ * Generate a dummy dataset similar to the MNIST dataset.
+ *
+ * Outputs:
+ * - X: Input data matrix, of shape (N, D).
+ * - Y: Target matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ */
+ # Generate dummy input data
+ N = 1024 # num examples
+ C = 1 # num input channels
+ Hin = 28 # input height
+ Win = 28 # input width
+ K = 10 # num target classes
+ X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+ classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+ Y = table(seq(1, N), classes) # one-hot encoding
+}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-all-args.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-all-args.dml b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
index bcb3ac3..ec6e087 100644
--- a/src/test/scripts/functions/paramserv/paramserv-all-args.dml
+++ b/src/test/scripts/functions/paramserv/paramserv-all-args.dml
@@ -20,7 +20,7 @@
#-------------------------------------------------------------
e1 = "element1"
-paramsList = list(e1)
+paramsList = list(e1=e1)
X = matrix(1, rows=2, cols=3)
Y = matrix(2, rows=2, cols=3)
X_val = matrix(3, rows=2, cols=3)
@@ -35,7 +35,7 @@ aggregation = function (matrix[double] input) return (matrix[double] output) {
}
e2 = "element2"
-hps = list(e2)
+hps = list(e2=e2)
# Use paramserv function
paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
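
The switch from the positional list(e1) to the named list(e1=e1) in this script (and in the other test scripts touched by this commit) appears to reflect that the paramserv model and hyperparameter lists are consumed by name, as the mnist_lenet scripts do with as.matrix(model["W1"]).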
http://git-wip-us.apache.org/repos/asf/systemml/blob/97018d4e/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml b/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
deleted file mode 100644
index 5aed767..0000000
--- a/src/test/scripts/functions/paramserv/paramserv-ipa-test.dml
+++ /dev/null
@@ -1,47 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-e1 = "element1"
-paramsList = list(e1)
-X = matrix(1, rows=2, cols=3)
-Y = matrix(2, rows=2, cols=3)
-X_val = matrix(3, rows=2, cols=3)
-Y_val = matrix(4, rows=2, cols=3)
-
-gradients = function (matrix[double] input) return (matrix[double] output) {
- output = input
-}
-
-aggregation = function (matrix[double] input) return (matrix[double] output) {
- output = input
-}
-
-e2 = "element2"
-hps = list(e2)
-
-# Use paramserv function
-paramsList2 = list(1, 2, 3)
-
-if (length(paramsList2) == 3) {
- paramsList2 = paramserv(model=paramsList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="gradients", agg="aggregation", mode="LOCAL", utype="BSP", freq="EPOCH", epochs=100, batchsize=64, k=7, scheme="DISJOINT_CONTIGUOUS", hyperparams=hps, checkpointing="NONE")
-}
-
-print(length(paramsList2))
\ No newline at end of file