You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/10/30 18:56:47 UTC
[systemds] branch master updated: [SYSTEMDS-2550] Federated
Parameter Server
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 428016c [SYSTEMDS-2550] Federated Parameter Server
428016c is described below
commit 428016c38fb55a3b6094334c7c876b71bf4bf3f7
Author: Tobias Rieger <to...@icloud.com>
AuthorDate: Tue Aug 25 10:23:00 2020 +0200
[SYSTEMDS-2550] Federated Parameter Server
This commit adds a federated parameter server to the system.
This allows for federated training of neural networks.
In particular, two architectures have been verified: a feed-forward NN and
a convolutional NN.
Closes #1075
---
.../apache/sysds/hops/recompile/Recompiler.java | 2 +-
.../java/org/apache/sysds/parser/Statement.java | 16 +-
.../controlprogram/caching/MatrixObject.java | 21 +-
.../federated/FederatedWorkerHandler.java | 6 +-
.../paramserv/FederatedPSControlThread.java | 560 +++++++++++++++++++++
.../controlprogram/paramserv/ParamServer.java | 5 +
.../controlprogram/paramserv/ParamservUtils.java | 47 +-
.../paramserv/dp/DataPartitionFederatedScheme.java | 88 ++++
.../paramserv/dp/FederatedDataPartitioner.java | 46 ++
.../dp/KeepDataOnWorkerFederatedScheme.java | 32 ++
.../paramserv/dp/ShuffleFederatedScheme.java | 33 ++
.../sysds/runtime/instructions/cp/ListObject.java | 145 +++++-
.../cp/ParamservBuiltinCPInstruction.java | 98 +++-
.../sysds/runtime/util/ProgramConverter.java | 40 +-
.../component/paramserv/SerializationTest.java | 86 +++-
.../paramserv/FederatedParamservTest.java | 195 +++++++
.../scripts/functions/federated/paramserv/CNN.dml | 474 +++++++++++++++++
.../federated/paramserv/FederatedParamservTest.dml | 57 +++
.../functions/federated/paramserv/TwoNN.dml | 299 +++++++++++
19 files changed, 2185 insertions(+), 65 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d048863..c785cfc 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1039,7 +1039,7 @@ public class Recompiler
}
}
- private static void rRecompileProgramBlock2Forced( ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et ) {
+ public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long tid, HashSet<String> fnStack, ExecType et ) {
if (pb instanceof WhileProgramBlock)
{
WhileProgramBlock pbTmp = (WhileProgramBlock)pb;
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index f0bdd66..b61b0d6 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -71,7 +71,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
public enum PSModeType {
- LOCAL, REMOTE_SPARK
+ FEDERATED, LOCAL, REMOTE_SPARK
}
public static final String PS_UPDATE_TYPE = "utype";
public enum PSUpdateType {
@@ -94,12 +94,26 @@ public abstract class Statement implements ParseInfo
public enum PSScheme {
DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, OVERLAP_RESHUFFLE
}
+ // Scheme names for the federated parameter server.
+ // NOTE(review): presumably these map to the dp/KeepDataOnWorkerFederatedScheme and
+ // dp/ShuffleFederatedScheme partitioners added in the same commit - confirm.
+ public enum FederatedPSScheme {
+ KEEP_DATA_ON_WORKER, SHUFFLE
+ }
public static final String PS_HYPER_PARAMS = "hyperparams";
public static final String PS_CHECKPOINTING = "checkpointing";
public enum PSCheckpointing {
NONE, EPOCH, EPOCH10
}
+ // String constants related to federated parameter server functionality
+ // prefixed with code: "1701-NCC-" to not overwrite anything
+ public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
+ public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
+ public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
+ public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
+ public static final String PS_FED_GRADIENTS_FNAME = "1701-NCC-gradients_fname";
+ public static final String PS_FED_AGGREGATION_FNAME = "1701-NCC-aggregation_fname";
+ public static final String PS_FED_BATCHCOUNTER_VARID = "1701-NCC-batchcounter_varid";
+ public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
+
public abstract boolean controlStatement();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 85d8a8f..4fd2b06 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -66,7 +66,7 @@ import org.apache.sysds.runtime.util.IndexRange;
public class MatrixObject extends CacheableData<MatrixBlock>
{
private static final long serialVersionUID = 6374712373206495637L;
-
+
public enum UpdateType {
COPY,
INPLACE,
@@ -87,7 +87,7 @@ public class MatrixObject extends CacheableData<MatrixBlock>
private int _partitionSize = -1; //indicates n for BLOCKWISE_N
private String _partitionCacheName = null; //name of cache block
private MatrixBlock _partitionInMemory = null;
-
+
/**
* Constructor that takes the value type and the HDFS filename.
*
@@ -112,6 +112,23 @@ public class MatrixObject extends CacheableData<MatrixBlock>
_cache = null;
_data = null;
}
+
+ /**
+ * Constructor that takes the value type, HDFS filename, associated metadata and a MatrixBlock.
+ * Used to re-create a matrix object after serialization: the given block is assigned to
+ * the data field as-is (no copy is made here) and the cache reference is reset to null.
+ *
+ * @param vt value type
+ * @param file file name
+ * @param mtd metadata
+ * @param data matrix block data
+ */
+ public MatrixObject( ValueType vt, String file, MetaData mtd, MatrixBlock data) {
+ super (DataType.MATRIX, vt);
+ _metaData = mtd;
+ _hdfsFileName = file;
+ _cache = null;
+ _data = data;
+ }
/**
* Copy constructor that copies meta data but NO data.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 6764f12..e932785 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -244,11 +244,15 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
//wrap transferred cache block into cacheable data
- Data data = null;
+ Data data;
if( request.getParam(0) instanceof CacheBlock )
data = ExecutionContext.createCacheableData((CacheBlock) request.getParam(0));
else if( request.getParam(0) instanceof ScalarObject )
data = (ScalarObject) request.getParam(0);
+ else if( request.getParam(0) instanceof ListObject )
+ data = (ListObject) request.getParam(0);
+ else
+ throw new DMLRuntimeException("FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock or ScalarObject");
//set variable and construct empty response
ec.setVariable(varname, data);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
new file mode 100644
index 0000000..8fa0698
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.util.ProgramConverter;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+
+import static org.apache.sysds.runtime.util.ProgramConverter.*;
+
+public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
+ FederatedData _featuresData;
+ FederatedData _labelsData;
+ final long _batchCounterVarID;
+ final long _modelVarID;
+ int _totalNumBatches;
+
+ /**
+ * Creates the control thread for one federated worker.
+ * All arguments are forwarded unchanged to the PSWorker constructor; in addition two
+ * federated data IDs are reserved, one for the batch counter and one for the model.
+ *
+ * @param workerID id of this worker
+ * @param updFunc update function name, forwarded to PSWorker
+ * @param freq update frequency; call() supports BATCH and EPOCH
+ * @param epochs number of epochs to run
+ * @param batchSize batch size used to derive the number of batches in setup()
+ * @param ec execution context, forwarded to PSWorker
+ * @param ps parameter server the model is pulled from and gradients are pushed to
+ */
+ public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+ super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+
+ // generate the IDs for model and batch counter. These get overwritten on the federated worker each time
+ _batchCounterVarID = FederationUtils.getNextFedDataID();
+ _modelVarID = FederationUtils.getNextFedDataID();
+ }
+
+    /**
+     * Sets up the federated worker and control thread.
+     * <p>
+     * Resolves the federated data handles of features and labels, derives the number of
+     * batches from the feature row count and the batch size, wraps the gradient instruction
+     * (and, for EPOCH frequency, the aggregation instruction) in program blocks, serializes
+     * the program, and ships program plus meta data to the worker through a setup UDF.
+     *
+     * @throws DMLRuntimeException if the setup UDF reports a failure or cannot be executed;
+     *         the underlying exception is attached as the cause
+     */
+    public void setup() {
+        // prepare features and labels
+        _features.getFedMapping().forEachParallel((range, data) -> {
+            _featuresData = data;
+            return null;
+        });
+        _labels.getFedMapping().forEachParallel((range, data) -> {
+            _labelsData = data;
+            return null;
+        });
+
+        // calculate number of batches and get data size
+        long dataSize = _features.getNumRows();
+        _totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize);
+
+        // serialize program
+        // create program blocks for the instruction filtering
+        String programSerialized;
+        ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+
+        BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
+        gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
+        programBlocks.add(gradientProgramBlock);
+
+        // the aggregation function is only shipped when the worker updates its model locally
+        if(_freq == Statement.PSFrequency.EPOCH) {
+            BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
+            aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
+            programBlocks.add(aggProgramBlock);
+        }
+
+        StringBuilder sb = new StringBuilder();
+        sb.append(PROG_BEGIN);
+        sb.append( NEWLINE );
+        sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
+            programBlocks,
+            new HashMap<>(),
+            false
+        ));
+        sb.append(PROG_END);
+        programSerialized = sb.toString();
+
+        // write program and meta data to worker
+        Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+            _featuresData.getVarID(),
+            new setupFederatedWorker(_batchSize,
+                dataSize,
+                _totalNumBatches,
+                programSerialized,
+                _inst.getNamespace(),
+                _inst.getFunctionName(),
+                _ps.getAggInst().getFunctionName(),
+                // same key the setup UDF uses to store the hyper parameters on the worker
+                _ec.getListObject(Statement.PS_HYPER_PARAMS),
+                _batchCounterVarID,
+                _modelVarID
+            )
+        ));
+
+        try {
+            FederatedResponse response = udfResponse.get();
+            if(!response.isSuccessful())
+                throw new DMLRuntimeException("FederatedLocalPSThread: Setup UDF failed");
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Setup UDF" + e.getMessage(), e);
+        }
+    }
+
+ /**
+ * Setup UDF executed on the federated worker.
+ * Carries the serialized program and the training meta data and, when executed,
+ * installs the parsed program and the meta data in the worker's execution context
+ * under the Statement.PS_FED_* and Statement.PS_HYPER_PARAMS variable names.
+ */
+ private static class setupFederatedWorker extends FederatedUDF {
+ long _batchSize;
+ long _dataSize;
+ long _numBatches;
+ String _programString;
+ String _namespace;
+ String _gradientsFunctionName;
+ String _aggregationFunctionName;
+ ListObject _hyperParams;
+ long _batchCounterVarID;
+ long _modelVarID;
+
+ protected setupFederatedWorker(long batchSize, long dataSize, long numBatches, String programString, String namespace, String gradientsFunctionName, String aggregationFunctionName, ListObject hyperParams, long batchCounterVarID, long modelVarID) {
+ super(new long[]{}); // empty ID array: execute() does not read its data arguments
+ _batchSize = batchSize;
+ _dataSize = dataSize;
+ _numBatches = numBatches;
+ _programString = programString;
+ _namespace = namespace;
+ _gradientsFunctionName = gradientsFunctionName;
+ _aggregationFunctionName = aggregationFunctionName;
+ _hyperParams = hyperParams;
+ _batchCounterVarID = batchCounterVarID;
+ _modelVarID = modelVarID;
+ }
+
+ /**
+ * Parses the shipped program, sets it on the execution context and stores all
+ * meta data as variables.
+ *
+ * @param ec execution context of the federated worker
+ * @param data unused
+ * @return a response of type SUCCESS
+ */
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ // parse and set program
+ ec.setProgram(ProgramConverter.parseProgram(_programString, 0, false));
+
+ // set variables to ec
+ ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize));
+ ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize));
+ ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches));
+ ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace));
+ ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
+ ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
+ ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
+ ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new IntObject(_batchCounterVarID));
+ ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));
+
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+ }
+ }
+
+    /**
+     * Cleans up the execution context of the federated worker by running the
+     * teardown UDF on it and waiting for the response.
+     *
+     * @throws DMLRuntimeException if the teardown UDF reports a failure or cannot be
+     *         executed; the underlying exception is attached as the cause
+     */
+    public void teardown() {
+        // execute the teardown UDF on the worker
+        Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+            _featuresData.getVarID(),
+            new teardownFederatedWorker()
+        ));
+
+        try {
+            FederatedResponse response = udfResponse.get();
+            if(!response.isSuccessful())
+                throw new DMLRuntimeException("FederatedLocalPSThread: Teardown UDF failed");
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute Teardown UDF" + e.getMessage(), e);
+        }
+    }
+
+ /**
+ * Teardown UDF executed on the federated worker.
+ * Removes the Statement.PS_FED_* variables from the execution context and cleans up
+ * the list objects stored under Statement.PS_HYPER_PARAMS and Statement.PS_GRADIENTS.
+ */
+ private static class teardownFederatedWorker extends FederatedUDF {
+ protected teardownFederatedWorker() {
+ super(new long[]{}); // empty ID array: execute() does not read its data arguments
+ }
+
+ /**
+ * @param ec execution context of the federated worker
+ * @param data unused
+ * @return a response of type SUCCESS
+ */
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ // remove variables from ec
+ ec.removeVariable(Statement.PS_FED_BATCH_SIZE);
+ ec.removeVariable(Statement.PS_FED_DATA_SIZE);
+ ec.removeVariable(Statement.PS_FED_NUM_BATCHES);
+ ec.removeVariable(Statement.PS_FED_NAMESPACE);
+ ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
+ ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
+ ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
+ ec.removeVariable(Statement.PS_FED_MODEL_VARID);
+ ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);
+ ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+ }
+ }
+
+    /**
+     * Entry point of the functionality: runs the training loop that matches the
+     * configured update frequency and afterwards tears down the federated worker.
+     * The teardown is also attempted when training fails, so that the variables
+     * installed by setup() do not stay behind on the worker.
+     *
+     * @return always null
+     * @throws Exception in case the execution fails
+     */
+    @Override
+    public Void call() throws Exception {
+        try {
+            switch (_freq) {
+                case BATCH:
+                    computeBatch(_totalNumBatches);
+                    break;
+                case EPOCH:
+                    computeEpoch();
+                    break;
+                default:
+                    throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
+            }
+        } catch (Exception e) {
+            // best-effort cleanup; a teardown failure must not hide the training failure
+            try {
+                teardown();
+            } catch (Exception teardownFailure) {
+                e.addSuppressed(teardownFailure);
+            }
+            throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
+        }
+        teardown();
+        return null;
+    }
+
+ /**
+ * Pulls the current model for this worker from the parameter server.
+ *
+ * @return the list object returned by the parameter server for this worker id
+ */
+ protected ListObject pullModel() {
+ // Pull the global parameters from ps
+ return _ps.pull(_workerID);
+ }
+
+ /**
+ * Pushes the given gradients to the parameter server under this worker's id.
+ *
+ * @param gradients the gradients to push
+ */
+ protected void pushGradients(ListObject gradients) {
+ // Push the gradients to ps
+ _ps.push(_workerID, gradients);
+ }
+
+    /**
+     * Runs every epoch and synchronizes with the parameter server after each batch:
+     * pull the model, compute the batch gradients on the worker, push them back.
+     *
+     * @param numBatches the number of batches per epoch
+     */
+    protected void computeBatch(int numBatches) {
+        for (int epoch = 0; epoch < _epochs; epoch++) {
+            for (int batch = 0; batch < numBatches; batch++) {
+                ListObject currentModel = pullModel();
+                ListObject batchGradients = computeBatchGradients(currentModel, batch);
+                pushGradients(batchGradients);
+                // release both list objects once the gradients have been pushed
+                ParamservUtils.cleanupListObject(currentModel);
+                ParamservUtils.cleanupListObject(batchGradients);
+            }
+            String progress = "[+] " + this.getWorkerName() + " completed epoch " + epoch;
+            System.out.println(progress);
+        }
+    }
+
+    /**
+     * Computes a single specified batch on the federated worker.
+     * The batch counter and the model are first put on the worker under the IDs
+     * reserved in the constructor; then the batch-gradient UDF is executed there.
+     *
+     * @param model the current model from the parameter server
+     * @param batchCounter the current batch number needed for slicing the features and labels
+     * @return the gradient vector
+     * @throws DMLRuntimeException if a put or the UDF fails; the underlying exception
+     *         is attached as the cause
+     */
+    protected ListObject computeBatchGradients(ListObject model, int batchCounter) {
+        // put batch counter on federated worker
+        Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter)));
+
+        // put current model on federated worker
+        Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+        try {
+            if(!putParamsResponse.get().isSuccessful() || !putBatchCounterResponse.get().isSuccessful())
+                throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage(), e);
+        }
+
+        // create and execute the udf on the remote worker
+        Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+            _featuresData.getVarID(),
+            new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID})
+        ));
+
+        try {
+            Object[] responseData = udfResponse.get().getData();
+            return (ListObject) responseData[0];
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage(), e);
+        }
+    }
+
+ /**
+ * This is the code that will be executed on the federated Worker when computing a single batch.
+ * The input IDs are expected in the order features, labels, batch counter, model,
+ * which is the order used by computeBatchGradients when it builds this UDF.
+ */
+ private static class federatedComputeBatchGradients extends FederatedUDF {
+ protected federatedComputeBatchGradients(long[] inIDs) {
+ super(inIDs);
+ }
+
+ /**
+ * Slices the requested batch, runs the gradient function on it and returns the gradients.
+ *
+ * @param ec execution context of the federated worker
+ * @param data features, labels, batch counter and model, in that order
+ * @return a SUCCESS response carrying the gradients list object
+ */
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ // read in data by varid
+ MatrixObject features = (MatrixObject) data[0];
+ MatrixObject labels = (MatrixObject) data[1];
+ long batchCounter = ((IntObject) data[2]).getLongValue();
+ ListObject model = (ListObject) data[3];
+
+ // get data from execution context
+ long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+ long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+ String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+ String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+
+ // slice batch from feature and label matrix
+ // NOTE(review): begin/end look like a 1-based inclusive row range - confirm against ParamservUtils.sliceMatrix
+ long begin = batchCounter * batchSize + 1;
+ long end = Math.min((batchCounter + 1) * batchSize, dataSize); // last batch may be shorter
+ MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+ MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+ // prepare execution context
+ ec.setVariable(Statement.PS_MODEL, model);
+ ec.setVariable(Statement.PS_FEATURES, bFeatures);
+ ec.setVariable(Statement.PS_LABELS, bLabels);
+
+ // recreate gradient instruction and output
+ FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+ ArrayList<DataIdentifier> inputs = func.getInputParams();
+ ArrayList<DataIdentifier> outputs = func.getOutputParams();
+ CPOperand[] boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
+ func.getInputParamNames(), outputNames, "gradient function");
+ DataIdentifier gradientsOutput = outputs.get(0); // only the first output is read back
+
+ // calculate gradients
+ gradientsInstruction.processInstruction(ec);
+ ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+ // clean up sliced batch
+ // the variable to remove is named by the string form of the stored batch counter var id
+ ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
+ ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+ ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+
+ // model clean up - doing this twice is not an issue
+ ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
+ ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+ // return
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients);
+ }
+ }
+
+    /**
+     * Runs every epoch and synchronizes with the parameter server once per epoch:
+     * pull the model, let the worker compute the epoch gradients, push them back.
+     */
+    protected void computeEpoch() {
+        int epoch = 0;
+        while (epoch < _epochs) {
+            // fetch the global parameters from the parameter server
+            ListObject currentModel = pullModel();
+            ListObject epochGradients = computeEpochGradients(currentModel);
+            pushGradients(epochGradients);
+            String progress = "[+] " + this.getWorkerName() + " completed epoch " + epoch;
+            System.out.println(progress);
+            // release both list objects once the gradients have been pushed
+            ParamservUtils.cleanupListObject(currentModel);
+            ParamservUtils.cleanupListObject(epochGradients);
+            epoch++;
+        }
+    }
+
+    /**
+     * Computes one epoch on the federated worker; the worker updates its model locally
+     * between batches. The model is first put on the worker under the ID reserved in
+     * the constructor; then the epoch-gradient UDF is executed there.
+     *
+     * @param model the current model from the parameter server
+     * @return the gradient vector
+     * @throws DMLRuntimeException if the put or the UDF fails; the underlying exception
+     *         is attached as the cause
+     */
+    protected ListObject computeEpochGradients(ListObject model) {
+        // put current model on federated worker
+        Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+        try {
+            if(!putParamsResponse.get().isSuccessful())
+                throw new DMLRuntimeException("FederatedLocalPSThread: put was not successful");
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute put" + e.getMessage(), e);
+        }
+
+        // create and execute the udf on the remote worker
+        Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+            _featuresData.getVarID(),
+            new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID})
+        ));
+
+        try {
+            Object[] responseData = udfResponse.get().getData();
+            return (ListObject) responseData[0];
+        }
+        catch(Exception e) {
+            // keep the original exception as cause so the stack trace is not lost
+            throw new DMLRuntimeException("FederatedLocalPSThread: failed to execute UDF" + e.getMessage(), e);
+        }
+    }
+
+ /**
+ * This is the code that will be executed on the federated Worker when computing one epoch.
+ * The input IDs are expected in the order features, labels, model, which is the
+ * order used by computeEpochGradients when it builds this UDF.
+ */
+ private static class federatedComputeEpochGradients extends FederatedUDF {
+ protected federatedComputeEpochGradients(long[] inIDs) {
+ super(inIDs);
+ }
+
+ /**
+ * Iterates over all batches, accrues the gradients of every batch and, for all but
+ * the last batch, updates the model in the execution context with the aggregation function.
+ *
+ * @param ec execution context of the federated worker
+ * @param data features, labels and model, in that order
+ * @return a SUCCESS response carrying the accrued gradients
+ */
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ // read in data by varid
+ MatrixObject features = (MatrixObject) data[0];
+ MatrixObject labels = (MatrixObject) data[1];
+ ListObject model = (ListObject) data[2];
+
+ // get data from execution context
+ long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+ long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+ long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue();
+ String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+ String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+ String aggregationFuctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+
+ // recreate gradient instruction and output
+ FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+ ArrayList<DataIdentifier> inputs = func.getInputParams();
+ ArrayList<DataIdentifier> outputs = func.getOutputParams();
+ CPOperand[] boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
+ func.getInputParamNames(), outputNames, "gradient function");
+ DataIdentifier gradientsOutput = outputs.get(0); // only the first output is read back
+
+ // recreate aggregation instruction and output
+ // (the local variables of the gradient function are reused from here on)
+ func = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFuctionName, false);
+ inputs = func.getInputParams();
+ outputs = func.getOutputParams();
+ boundInputs = inputs.stream()
+ .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
+ outputNames = outputs.stream().map(DataIdentifier::getName)
+ .collect(Collectors.toCollection(ArrayList::new));
+ Instruction aggregationInstruction = new FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
+ func.getInputParamNames(), outputNames, "aggregation function");
+ DataIdentifier aggregationOutput = outputs.get(0); // only the first output is read back
+
+
+ ListObject accGradients = null; // stays null if numBatches is 0
+ // prepare execution context
+ ec.setVariable(Statement.PS_MODEL, model);
+ for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+ // slice batch from feature and label matrix
+ // NOTE(review): begin/end look like a 1-based inclusive row range - confirm against ParamservUtils.sliceMatrix
+ long begin = batchCounter * batchSize + 1;
+ long end = Math.min((batchCounter + 1) * batchSize, dataSize); // last batch may be shorter
+ MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+ MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+ // prepare execution context
+ ec.setVariable(Statement.PS_FEATURES, bFeatures);
+ ec.setVariable(Statement.PS_LABELS, bLabels);
+ boolean localUpdate = batchCounter < numBatches - 1; // no local model update after the last batch
+
+ // calculate intermediate gradients
+ gradientsInstruction.processInstruction(ec);
+ ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+ // TODO: is this equivalent for momentum based and AMS prob?
+ accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);
+
+ // Update the local model with gradients
+ if(localUpdate) {
+ // Invoke the aggregate function
+ aggregationInstruction.processInstruction(ec);
+ // Get the new model
+ model = ec.getListObject(aggregationOutput.getName());
+ // Set new model in execution context
+ ec.setVariable(Statement.PS_MODEL, model);
+ // clean up gradients and result
+ ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+ ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
+ }
+
+ // clean up sliced batch
+ ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+ ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+ }
+
+ // model clean up - doing this twice is not an issue
+ ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
+ ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
+ }
+ }
+
+ // Statistics methods
+ /**
+ * @return the display name of this worker, "Federated worker_" followed by the worker id
+ */
+ @Override
+ public String getWorkerName() {
+ return String.format("Federated worker_%d", _workerID);
+ }
+
+ @Override
+ protected void incWorkerNumber() {
+ // no-op: this implementation records nothing here
+ }
+
+ @Override
+ protected void accLocalModelUpdateTime(Timing time) {
+ // no-op: the given timing is ignored by this implementation
+ }
+
+ @Override
+ protected void accBatchIndexingTime(Timing time) {
+ // no-op: the given timing is ignored by this implementation
+ }
+
+ @Override
+ protected void accGradientComputeTime(Timing time) {
+ // no-op: the given timing is ignored by this implementation
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 276f56c..e420ed8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -57,6 +57,7 @@ public abstract class ParamServer
//aggregation service
protected ExecutionContext _ec;
private Statement.PSUpdateType _updateType;
+
private FunctionCallCPInstruction _inst;
private String _outputName;
private boolean[] _finishedStates; // Workers' finished states
@@ -232,4 +233,8 @@ public abstract class ParamServer
if (DMLScript.STATISTICS)
Statistics.accPSModelBroadcastTime((long) tBroad.stop());
}
+
+ /**
+ * Exposes the function call instruction held by this parameter server.
+ * FederatedPSControlThread uses it as the aggregation instruction when it
+ * serializes the program for the federated worker.
+ *
+ * @return the function call instruction stored in _inst
+ */
+ public FunctionCallCPInstruction getAggInst() {
+ return _inst;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 968cb1d..e63fb14 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -31,6 +31,7 @@ import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.Statement;
@@ -214,15 +215,21 @@ public class ParamservUtils {
}
+ // Convenience overload: delegates to the six-argument variant with forceExecTypeCP = false.
public static ExecutionContext createExecutionContext(ExecutionContext ec,
- LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+ LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+ {
+ return createExecutionContext(ec, varsMap, updFunc, aggFunc, k, false);
+ }
+
+ public static ExecutionContext createExecutionContext(ExecutionContext ec,
+ LocalVariableMap varsMap, String updFunc, String aggFunc, int k, boolean forceExecTypeCP)
{
Program prog = ec.getProgram();
// 1. Recompile the internal program blocks
- recompileProgramBlocks(k, prog.getProgramBlocks());
+ recompileProgramBlocks(k, prog.getProgramBlocks(), forceExecTypeCP);
// 2. Recompile the imported function blocks
prog.getFunctionProgramBlocks(false)
- .forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks()));
+ .forEach((fname, fvalue) -> recompileProgramBlocks(k, fvalue.getChildBlocks(), forceExecTypeCP));
// 3. Copy all functions
return ExecutionContextFactory.createContext(
@@ -249,6 +256,10 @@ public class ParamservUtils {
}
+ /**
+ * Convenience overload that delegates to the three-argument
+ * {@code recompileProgramBlocks} with {@code forceExecTypeCP = false}.
+ *
+ * @param k   level of parallelism, forwarded unchanged
+ * @param pbs program blocks to recompile, forwarded unchanged
+ */
public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs) {
+ recompileProgramBlocks(k, pbs, false);
+ }
+
+ public static void recompileProgramBlocks(int k, List<ProgramBlock> pbs, boolean forceExecTypeCP) {
// Reset the visit status from root
for (ProgramBlock pb : pbs)
DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
@@ -256,43 +267,49 @@ public class ParamservUtils {
// Should recursively assign the level of parallelism
// and recompile the program block
try {
- rAssignParallelism(pbs, k, false);
+ if(forceExecTypeCP)
+ rAssignParallelismAndRecompile(pbs, k, true, forceExecTypeCP);
+ else
+ rAssignParallelismAndRecompile(pbs, k, false, forceExecTypeCP);
} catch (IOException e) {
throw new DMLRuntimeException(e);
}
}
- private static boolean rAssignParallelism(List<ProgramBlock> pbs, int k, boolean recompiled) throws IOException {
+ private static boolean rAssignParallelismAndRecompile(List<ProgramBlock> pbs, int k, boolean recompiled, boolean forceExecTypeCP) throws IOException {
for (ProgramBlock pb : pbs) {
if (pb instanceof ParForProgramBlock) {
ParForProgramBlock pfpb = (ParForProgramBlock) pb;
pfpb.setDegreeOfParallelism(k);
- recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(pfpb.getChildBlocks(), 1, recompiled, forceExecTypeCP);
} else if (pb instanceof ForProgramBlock) {
- recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(((ForProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
} else if (pb instanceof WhileProgramBlock) {
- recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
} else if (pb instanceof FunctionProgramBlock) {
- recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
} else if (pb instanceof IfProgramBlock) {
IfProgramBlock ipb = (IfProgramBlock) pb;
- recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksIfBody(), k, recompiled, forceExecTypeCP);
if (ipb.getChildBlocksElseBody() != null)
- recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksElseBody(), k, recompiled, forceExecTypeCP);
} else {
StatementBlock sb = pb.getStatementBlock();
for (Hop hop : sb.getHops())
- recompiled |= rAssignParallelism(hop, k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(hop, k, recompiled);
}
// Recompile the program block
if (recompiled) {
- Recompiler.recompileProgramBlockInstructions(pb);
+ if(forceExecTypeCP)
+ Recompiler.rRecompileProgramBlock2Forced(pb, pb.getThreadID(), new HashSet<>(), LopProperties.ExecType.CP);
+ else
+ Recompiler.recompileProgramBlockInstructions(pb);
}
}
return recompiled;
}
- private static boolean rAssignParallelism(Hop hop, int k, boolean recompiled) {
+ private static boolean rAssignParallelismAndRecompile(Hop hop, int k, boolean recompiled) {
if (hop.isVisited()) {
return recompiled;
}
@@ -304,7 +321,7 @@ public class ParamservUtils {
}
ArrayList<Hop> inputs = hop.getInput();
for (Hop h : inputs) {
- recompiled |= rAssignParallelism(h, k, recompiled);
+ recompiled |= rAssignParallelismAndRecompile(h, k, recompiled);
}
hop.setVisited();
return recompiled;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
new file mode 100644
index 0000000..4183372
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+public abstract class DataPartitionFederatedScheme {
+
+ public static final class Result {
+ public final List<MatrixObject> pFeatures;
+ public final List<MatrixObject> pLabels;
+ public final int workerNum;
+
+ public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
+ this.pFeatures = pFeatures;
+ this.pLabels = pLabels;
+ this.workerNum = workerNum;
+ }
+ }
+
+ public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+
+ /**
+ * Takes a row federated Matrix and slices it into a matrix for each worker.
+ * Every (range, data) entry of the federation map becomes one new MatrixObject whose
+ * federation map holds only that entry and whose dimensions are taken from the range.
+ *
+ * NOTE(review): slices are appended from inside forEachParallel; if the callback runs
+ * concurrently (which the synchronized list suggests), the order of the returned list is
+ * not guaranteed, so index i of the feature slices and index i of the label slices may
+ * not belong to the same worker -- confirm and, if so, order the slices deterministically.
+ *
+ * @param fedMatrix the federated input matrix
+ * @return one single-entry federated MatrixObject per entry of the federation map
+ * @throws DMLRuntimeException if the matrix is not row federated
+ */
+ static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
+ if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
+
+ List<MatrixObject> slices = Collections.synchronizedList(new ArrayList<>());
+ fedMatrix.getFedMapping().forEachParallel((range, data) -> {
+ // Create sliced matrix object under a fresh unique variable name
+ MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
+ // Warning: the meta data must be a MetaDataFormat, not a plain MetaData
+ slice.setMetaData(new MetaDataFormat(
+ new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
+ Types.FileFormat.BINARY)
+ );
+
+ // Create new federation map holding only this (range, data) entry, reusing the source map's ID
+ HashMap<FederatedRange, FederatedData> newFedHashMap = new HashMap<>();
+ newFedHashMap.put(range, data);
+ slice.setFedMapping(new FederationMap(fedMatrix.getFedMapping().getID(), newFedHashMap));
+ slice.getFedMapping().setType(FederationMap.FType.ROW);
+
+ slices.add(slice);
+ return null;
+ });
+
+ return slices;
+ }
+ else {
+ throw new DMLRuntimeException("Federated data partitioner: " +
+ "currently only supports row federated data");
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
new file mode 100644
index 0000000..4cdfb95
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+/**
+ * Front end for federated data partitioning: picks the scheme implementation that matches the
+ * requested {@code Statement.FederatedPSScheme} and forwards partitioning requests to it.
+ */
+public class FederatedDataPartitioner {
+
+ private final DataPartitionFederatedScheme _scheme;
+
+ /**
+ * @param scheme the requested federated partitioning scheme
+ * @throws DMLRuntimeException if the scheme has no implementation here
+ */
+ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+     _scheme = createScheme(scheme);
+ }
+
+ /**
+ * Forwards to the selected scheme.
+ *
+ * @param features federated feature matrix
+ * @param labels federated label matrix
+ * @return the partitioning result produced by the selected scheme
+ */
+ public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
+     return _scheme.doPartitioning(features, labels);
+ }
+
+ // Maps the scheme constant to its implementation; unknown constants are rejected.
+ private static DataPartitionFederatedScheme createScheme(Statement.FederatedPSScheme scheme) {
+     switch (scheme) {
+         case KEEP_DATA_ON_WORKER:
+             return new KeepDataOnWorkerFederatedScheme();
+         case SHUFFLE:
+             return new ShuffleFederatedScheme();
+         default:
+             throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme));
+     }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
new file mode 100644
index 0000000..06feded
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import java.util.List;
+
+/**
+ * Federated partitioning scheme that only slices the inputs: both matrices are handed to
+ * {@code sliceFederatedMatrix} and the resulting lists are returned as they are.
+ */
+public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
+ /**
+ * Slices features and labels and reports the number of feature slices as the worker count.
+ *
+ * @param features federated feature matrix
+ * @param labels federated label matrix
+ * @return feature slices, label slices and the number of feature slices
+ */
+ @Override
+ public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+     final List<MatrixObject> featureSlices = sliceFederatedMatrix(features);
+     final List<MatrixObject> labelSlices = sliceFederatedMatrix(labels);
+     final int numWorkers = featureSlices.size();
+     return new Result(featureSlices, labelSlices, numWorkers);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
new file mode 100644
index 0000000..d6d8cfc
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+import java.util.List;
+
+/**
+ * Federated partitioning scheme selected for the SHUFFLE option.
+ * NOTE(review): as written, this class performs exactly the same slicing as
+ * KeepDataOnWorkerFederatedScheme; no reordering of rows happens in this class.
+ */
+public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
+ /**
+ * Slices features and labels and reports the number of feature slices as the worker count.
+ *
+ * @param features federated feature matrix
+ * @param labels federated label matrix
+ * @return feature slices, label slices and the number of feature slices
+ */
+ @Override
+ public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+     final List<MatrixObject> featureSlices = sliceFederatedMatrix(features);
+     final List<MatrixObject> labelSlices = sliceFederatedMatrix(labels);
+     final int numWorkers = featureSlices.size();
+     return new Result(featureSlices, labelSlices, numWorkers);
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 4f726ee..a8397cb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -19,17 +19,29 @@
package org.apache.sysds.runtime.instructions.cp;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
-public class ListObject extends Data {
+public class ListObject extends Data implements Externalizable {
private static final long serialVersionUID = 3652422061598967358L;
private final List<Data> _data;
@@ -37,6 +49,14 @@ public class ListObject extends Data {
private List<String> _names = null;
private int _nCacheable;
private List<LineageItem> _lineage = null;
+
+ /*
+ * Public no-arg constructor as required by the Externalizable interface: creates an
+ * empty list (data type LIST, value type UNKNOWN) with a fresh ArrayList as backing
+ * store, which readExternal then fills with the deserialized entries.
+ */
+ public ListObject() {
+ super(DataType.LIST, ValueType.UNKNOWN);
+ _data = new ArrayList<>();
+ }
public ListObject(List<Data> data) {
this(data, null, null);
@@ -286,4 +306,127 @@ public class ListObject extends Data {
sb.append(")");
return sb.toString();
}
+
+ /**
+ * Serializes this list for the Externalizable contract. The stream layout is: length,
+ * number of cacheable entries, a flag telling whether names follow (plus the names if so),
+ * and then per element its data type, value type, privacy constraint and a type-specific
+ * payload (nested list, matrix characteristics + file format + matrix block, or the
+ * scalar's string value).
+ *
+ * @param out object output
+ * @throws IOException if IOException occurs
+ * @throws DMLRuntimeException if an element is neither a LIST, a MATRIX nor a SCALAR
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+     // write out length
+     out.writeInt(getLength());
+     // write out num cacheable
+     out.writeInt(_nCacheable);
+
+     // write out names for named list
+     out.writeBoolean(getNames() != null);
+     if(getNames() != null) {
+         for (int i = 0; i < getLength(); i++) {
+             out.writeObject(_names.get(i));
+         }
+     }
+
+     // write out data
+     for(int i = 0; i < getLength(); i++) {
+         Data d = getData(i);
+         out.writeObject(d.getDataType());
+         out.writeObject(d.getValueType());
+         out.writeObject(d.getPrivacyConstraint());
+         switch(d.getDataType()) {
+             case LIST:
+                 ListObject lo = (ListObject) d;
+                 out.writeObject(lo);
+                 break;
+             case MATRIX:
+                 MatrixObject mo = (MatrixObject) d;
+                 // the cast requires the matrix to carry MetaDataFormat meta data
+                 MetaDataFormat md = (MetaDataFormat) mo.getMetaData();
+                 DataCharacteristics dc = md.getDataCharacteristics();
+
+                 // written as boxed objects; readExternal reads them back in this order
+                 out.writeObject(dc.getRows());
+                 out.writeObject(dc.getCols());
+                 out.writeObject(dc.getBlocksize());
+                 out.writeObject(dc.getNonZeros());
+                 out.writeObject(md.getFileFormat());
+                 out.writeObject(mo.acquireReadAndRelease());
+                 break;
+             case SCALAR:
+                 ScalarObject so = (ScalarObject) d;
+                 out.writeObject(so.getStringValue());
+                 break;
+             default:
+                 // report the offending element's type; the bare name dataType is not the element's
+                 throw new DMLRuntimeException("Unable to serialize datatype " + d.getDataType());
+         }
+     }
+ }
+
+ /**
+ * Restores the list contents from the layout produced by writeExternal: length, number of
+ * cacheable entries, optional names, and then per element its data type, value type,
+ * privacy constraint and a type-specific payload. Matrices are rebuilt under a fresh
+ * unique variable name; scalars are parsed back from their string value. Entries are
+ * appended to the existing backing list.
+ *
+ * @param in object input
+ * @throws IOException if IOException occurs
+ * @throws ClassNotFoundException if the class of a serialized object cannot be found
+ * @throws DMLRuntimeException if an unsupported data type or scalar value type is read
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+     // read in length
+     int length = in.readInt();
+     // read in num cacheable
+     _nCacheable = in.readInt();
+
+     // read in names (primitive flag, no boxing needed)
+     boolean named = in.readBoolean();
+     if(named) {
+         _names = new ArrayList<>();
+         for (int i = 0; i < length; i++) {
+             _names.add((String) in.readObject());
+         }
+     }
+
+     // read in data
+     for(int i = 0; i < length; i++) {
+         DataType dataType = (DataType) in.readObject();
+         ValueType valueType = (ValueType) in.readObject();
+         PrivacyConstraint privacyConstraint = (PrivacyConstraint) in.readObject();
+         Data d;
+         switch(dataType) {
+             case LIST:
+                 d = (ListObject) in.readObject();
+                 break;
+             case MATRIX:
+                 // same order as written: rows, cols, block size, non-zeros, file format, block
+                 long rows = (long) in.readObject();
+                 long cols = (long) in.readObject();
+                 int blockSize = (int) in.readObject();
+                 long nonZeros = (long) in.readObject();
+                 Types.FileFormat fileFormat = (Types.FileFormat) in.readObject();
+
+                 // construct objects and set meta data
+                 MatrixCharacteristics matrixCharacteristics = new MatrixCharacteristics(rows, cols, blockSize, nonZeros);
+                 MetaDataFormat metaDataFormat = new MetaDataFormat(matrixCharacteristics, fileFormat);
+                 MatrixBlock matrixBlock = (MatrixBlock) in.readObject();
+
+                 d = new MatrixObject(valueType, Dag.getNextUniqueVarname(Types.DataType.MATRIX), metaDataFormat, matrixBlock);
+                 break;
+             case SCALAR:
+                 String value = (String) in.readObject();
+                 ScalarObject so;
+                 switch (valueType) {
+                     case INT64: so = new IntObject(Long.parseLong(value)); break;
+                     case FP64: so = new DoubleObject(Double.parseDouble(value)); break;
+                     case BOOLEAN: so = new BooleanObject(Boolean.parseBoolean(value)); break;
+                     case STRING: so = new StringObject(value); break;
+                     default:
+                         throw new DMLRuntimeException("Unable to parse valuetype " + valueType);
+                 }
+                 d = so;
+                 break;
+             default:
+                 throw new DMLRuntimeException("Unable to deserialize datatype " + dataType);
+         }
+         d.setPrivacyConstraints(privacyConstraint);
+         _data.add(d);
+     }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index c42ec91..5e8ad32 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -52,6 +52,7 @@ import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.Statement.PSFrequency;
import org.apache.sysds.parser.Statement.PSModeType;
import org.apache.sysds.parser.Statement.PSScheme;
@@ -61,13 +62,16 @@ import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
+import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -91,16 +95,87 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
@Override
public void processInstruction(ExecutionContext ec) {
- PSModeType mode = getPSMode();
- switch (mode) {
- case LOCAL:
- runLocally(ec, mode);
- break;
- case REMOTE_SPARK:
- runOnSpark((SparkExecutionContext) ec, mode);
- break;
- default:
- throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ // check if the input is federated
+ if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+ ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+ runFederated(ec);
+ }
+ // if not federated check mode
+ else {
+ PSModeType mode = getPSMode();
+ switch (mode) {
+ case LOCAL:
+ runLocally(ec, mode);
+ break;
+ case REMOTE_SPARK:
+ runOnSpark((SparkExecutionContext) ec, mode);
+ break;
+ default:
+ throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ }
+ }
+ }
+
+ private void runFederated(ExecutionContext ec) {
+ System.out.println("PARAMETER SERVER");
+ System.out.println("[+] Running in federated mode");
+
+ // get inputs
+ PSFrequency freq = getFrequency();
+ PSUpdateType updateType = getUpdateType();
+ String updFunc = getParam(PS_UPDATE_FUN);
+ String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+ // partition federated data
+ DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+ .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
+ List<MatrixObject> pFeatures = result.pFeatures;
+ List<MatrixObject> pLabels = result.pLabels;
+ int workerNum = result.workerNum;
+
+ // setup threading
+ BasicThreadFactory factory = new BasicThreadFactory.Builder()
+ .namingPattern("workers-pool-thread-%d").build();
+ ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
+
+ // Get the compiled execution context
+ LocalVariableMap newVarsMap = createVarsMap(ec);
+ // Level of par is 1 because one worker will be launched per task
+ // TODO: Fix recompilation
+ ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1, true);
+ // Create workers' execution context
+ List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
+ // Create the agg service's execution context
+ ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+ // Create the parameter server
+ ListObject model = ec.getListObject(getParam(PS_MODEL));
+ ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
+ // Create the local workers
+ List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
+ .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps))
+ .collect(Collectors.toList());
+
+ if(workerNum != threads.size()) {
+ throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
+ }
+
+ // Set features and lables for the control threads and write the program and instructions and hyperparams to the federated workers
+ for (int i = 0; i < threads.size(); i++) {
+ threads.get(i).setFeatures(pFeatures.get(i));
+ threads.get(i).setLabels(pLabels.get(i));
+ threads.get(i).setup();
+ }
+
+ try {
+ // Launch the worker threads and wait for completion
+ for (Future<Void> ret : es.invokeAll(threads))
+ ret.get(); //error handling
+ // Fetch the final model from ps
+ ec.setVariable(output.getName(), ps.getResult());
+ } catch (InterruptedException | ExecutionException e) {
+ throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
+ } finally {
+ es.shutdownNow();
}
}
@@ -150,7 +225,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
// Create remote workers
- SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
+ SparkPSWorker worker = new SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
getFrequency(), getEpochs(), getBatchSize(), program, clsMap, sec.getSparkContext().getConf(),
server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);
@@ -333,6 +408,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
*/
private static ParamServer createPS(PSModeType mode, String aggFunc, PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
switch (mode) {
+ case FEDERATED:
case LOCAL:
case REMOTE_SPARK:
return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 7dad319..16472c7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -817,7 +817,7 @@ public class ProgramConverter
//handle program
sb.append(PROG_BEGIN);
sb.append( NEWLINE );
- sb.append( serializeProgram(prog, pbs, clsMap) );
+ sb.append( serializeProgram(prog, pbs, clsMap, true) );
sb.append(PROG_END);
sb.append( NEWLINE );
sb.append( COMPONENTS_DELIM );
@@ -849,32 +849,32 @@ public class ProgramConverter
return sb.toString();
}
- private static String serializeProgram( Program prog, ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap ) {
- //note program contains variables, programblocks and function program blocks
+ public static String serializeProgram( Program prog, ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap, boolean opt) {
+ //note program contains variables, programblocks and function program blocks
//but in order to avoid redundancy, we only serialize function program blocks
- HashMap<String, FunctionProgramBlock> fpb = prog.getFunctionProgramBlocks();
+ HashMap<String, FunctionProgramBlock> fpb = prog.getFunctionProgramBlocks(opt);
HashSet<String> cand = new HashSet<>();
- rFindSerializationCandidates(pbs, cand);
- return rSerializeFunctionProgramBlocks( fpb, cand, clsMap );
+ rFindSerializationCandidates(pbs, cand, opt);
+ return rSerializeFunctionProgramBlocks(fpb, cand, clsMap);
}
- private static void rFindSerializationCandidates( ArrayList<ProgramBlock> pbs, HashSet<String> cand )
+ private static void rFindSerializationCandidates( ArrayList<ProgramBlock> pbs, HashSet<String> cand, boolean opt)
{
for( ProgramBlock pb : pbs )
{
if( pb instanceof WhileProgramBlock ) {
WhileProgramBlock wpb = (WhileProgramBlock) pb;
- rFindSerializationCandidates(wpb.getChildBlocks(), cand );
+ rFindSerializationCandidates(wpb.getChildBlocks(), cand, opt);
}
else if ( pb instanceof ForProgramBlock || pb instanceof ParForProgramBlock ) {
ForProgramBlock fpb = (ForProgramBlock) pb;
- rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+ rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
}
else if ( pb instanceof IfProgramBlock ) {
IfProgramBlock ipb = (IfProgramBlock) pb;
- rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand);
+ rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand, opt);
if( ipb.getChildBlocksElseBody() != null )
- rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand);
+ rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand, opt);
}
else if( pb instanceof BasicProgramBlock ) {
BasicProgramBlock bpb = (BasicProgramBlock) pb;
@@ -885,8 +885,8 @@ public class ProgramConverter
if( !cand.contains(fkey) ) { //memoization for multiple calls, recursion
cand.add( fkey ); //add to candidates
//investigate chains of function calls
- FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), fci.getFunctionName());
- rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+ FunctionProgramBlock fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), fci.getFunctionName(), opt);
+ rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
}
}
}
@@ -985,12 +985,12 @@ public class ProgramConverter
}
@SuppressWarnings("all")
- private static String serializeInstructions( ArrayList<Instruction> inst, HashMap<String, byte[]> clsMap )
+ private static String serializeInstructions( ArrayList<Instruction> inst, HashMap<String, byte[]> clsMap )
{
StringBuilder sb = new StringBuilder();
int count = 0;
for( Instruction linst : inst ) {
- //check that only cp instruction are transmitted
+ //check that only cp instruction are transmitted
if( !( linst instanceof CPInstruction) )
throw new DMLRuntimeException( NOT_SUPPORTED_SPARK_INSTRUCTION + " " +linst.getClass().getName()+"\n"+linst );
@@ -1098,7 +1098,6 @@ public class ProgramConverter
continue;
if( count>0 ) {
sb.append( ELEMENT_DELIM );
- sb.append( NEWLINE );
}
sb.append( pb.getKey() );
sb.append( KEY_VALUE_DELIM );
@@ -1115,7 +1114,6 @@ public class ProgramConverter
for( ProgramBlock pb : pbs ) {
if( count>0 ) {
sb.append( ELEMENT_DELIM );
- sb.append(NEWLINE);
}
sb.append( rSerializeProgramBlock(pb, clsMap) );
count++;
@@ -1339,6 +1337,10 @@ public class ProgramConverter
}
+ /**
+ * Convenience overload that delegates to {@code parseProgram(in, id, true)}.
+ *
+ * @param in serialized program string
+ * @param id identifier forwarded unchanged to the three-argument overload
+ * @return the parsed program
+ */
public static Program parseProgram( String in, int id ) {
+ return parseProgram(in, id, true);
+ }
+
+ public static Program parseProgram( String in, int id, boolean opt ) {
String lin = in.substring( PROG_BEGIN.length(),in.length()- PROG_END.length()).trim();
Program prog = new Program();
HashMap<String,FunctionProgramBlock> fc = parseFunctionProgramBlocks(lin, prog, id);
@@ -1346,7 +1348,7 @@ public class ProgramConverter
String[] keypart = e.getKey().split( Program.KEY_DELIM );
String namespace = keypart[0];
String name = keypart[1];
- prog.addFunctionProgramBlock(namespace, name, e.getValue());
+ prog.addFunctionProgramBlock(namespace, name, e.getValue(), opt);
}
return prog;
}
@@ -1354,7 +1356,7 @@ public class ProgramConverter
private static LocalVariableMap parseVariables(String in) {
LocalVariableMap ret = null;
if( in.length()> VARS_BEGIN.length() + VARS_END.length()) {
- String varStr = in.substring( VARS_BEGIN.length(),in.length()- VARS_END.length()).trim();
+ String varStr = in.substring( VARS_BEGIN.length(),in.length() - VARS_END.length()).trim();
ret = LocalVariableMap.deserialize(varStr);
}
else { //empty input symbol table
diff --git a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
index 665997a..bf47f19 100644
--- a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
+++ b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
@@ -19,8 +19,15 @@
package org.apache.sysds.test.component.paramserv;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectInput;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
import java.util.Arrays;
+import java.util.Collection;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.junit.Assert;
import org.junit.Test;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -29,32 +36,80 @@ import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+@RunWith(Parameterized.class)
public class SerializationTest {
+ private int _named;
+
+ @Parameterized.Parameters
+ public static Collection named() {
+ return Arrays.asList(new Object[][] {
+ { 0 },
+ { 1 }
+ });
+ }
+
+ public SerializationTest(Integer named) {
+ this._named = named;
+ }
@Test
- public void serializeUnnamedListObject() {
+ public void serializeListObject() {
MatrixObject mo1 = generateDummyMatrix(10);
MatrixObject mo2 = generateDummyMatrix(20);
IntObject io = new IntObject(30);
- ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
- String serial = ProgramConverter.serializeDataObject("key", lo);
- Object[] obj = ProgramConverter.parseDataObject(serial);
- ListObject actualLO = (ListObject) obj[1];
- MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
- MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
- IntObject actualIO = (IntObject) actualLO.slice(2);
- Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), actualMO1.acquireRead().getDenseBlockValues(), 0);
- Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), actualMO2.acquireRead().getDenseBlockValues(), 0);
- Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
+ ListObject lot = new ListObject(Arrays.asList(mo2));
+ ListObject lo;
+
+ if (_named == 1)
+ lo = new ListObject(Arrays.asList(mo1, lot, io), Arrays.asList("e1", "e2", "e3"));
+ else
+ lo = new ListObject(Arrays.asList(mo1, lot, io));
+
+ ListObject loDeserialized = null;
+
+ // serialize and back
+ try {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(bos);
+ out.writeObject(lo);
+ out.flush();
+ byte[] loBytes = bos.toByteArray();
+
+ ByteArrayInputStream bis = new ByteArrayInputStream(loBytes);
+ ObjectInput in = new ObjectInputStream(bis);
+ loDeserialized = (ListObject) in.readObject();
+ }
+ catch(Exception e){
+ System.out.println("Error while serializing and deserializing to bytes: " + e);
+ assert(false);
+ }
+
+ MatrixObject mo1Deserialized = (MatrixObject) loDeserialized.getData(0);
+ ListObject lotDeserialized = (ListObject) loDeserialized.getData(1);
+ MatrixObject mo2Deserialized = (MatrixObject) lotDeserialized.getData(0);
+ IntObject ioDeserialized = (IntObject) loDeserialized.getData(2);
+
+ if (_named == 1)
+ Assert.assertEquals(lo.getNames(), loDeserialized.getNames());
+
+ Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), mo1Deserialized.acquireRead().getDenseBlockValues(), 0);
+ Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), mo2Deserialized.acquireRead().getDenseBlockValues(), 0);
+ Assert.assertEquals(io.getLongValue(), ioDeserialized.getLongValue());
}
@Test
- public void serializeNamedListObject() {
+ public void serializeListObjectProgramConverter() {
MatrixObject mo1 = generateDummyMatrix(10);
MatrixObject mo2 = generateDummyMatrix(20);
IntObject io = new IntObject(30);
- ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io), Arrays.asList("e1", "e2", "e3"));
+ ListObject lo;
+ if (_named == 1)
+ lo = new ListObject(Arrays.asList(mo1, mo2, io), Arrays.asList("e1", "e2", "e3"));
+ else
+ lo = new ListObject(Arrays.asList(mo1, mo2, io));
String serial = ProgramConverter.serializeDataObject("key", lo);
Object[] obj = ProgramConverter.parseDataObject(serial);
@@ -62,7 +117,10 @@ public class SerializationTest {
MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
IntObject actualIO = (IntObject) actualLO.slice(2);
- Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
+ if (_named == 1)
+ Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), actualMO1.acquireRead().getDenseBlockValues(), 0);
Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), actualMO2.acquireRead().getDenseBlockValues(), 0);
Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
new file mode 100644
index 0000000..194df09
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedParamservTest extends AutomatedTestBase {
+ private final static String TEST_DIR = "functions/federated/paramserv/";
+ private final static String TEST_NAME = "FederatedParamservTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
+ private final static int _blocksize = 1024;
+
+ private final String _networkType;
+ private final int _numFederatedWorkers;
+ private final int _examplesPerWorker;
+ private final int _epochs;
+ private final int _batch_size;
+ private final double _eta;
+ private final String _utype;
+ private final String _freq;
+
+ private Types.ExecMode _platformOld;
+
+ // parameters
+ @Parameterized.Parameters
+ public static Collection<Object[]> parameters() {
+ return Arrays.asList(new Object[][] {
+ //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency
+ {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+ {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+ {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+ {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+ {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+ {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+ {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+ {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+ {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+ {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+ {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+ {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"},
+ {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+ {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+ {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+ {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}
+ });
+ }
+
+ public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) {
+ _networkType = networkType;
+ _numFederatedWorkers = numFederatedWorkers;
+ _examplesPerWorker = examplesPerWorker;
+ _batch_size = batch_size;
+ _epochs = epochs;
+ _eta = eta;
+ _utype = utype;
+ _freq = freq;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+
+ _platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Override
+ public void tearDown() {
+
+ rtplatform = _platformOld;
+ }
+
+ @Test
+ public void federatedParamserv() {
+ // config
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ setOutputBuffering(true);
+
+ int C = 1, Hin = 28, Win = 28;
+ int numFeatures = C*Hin*Win;
+ int numLabels = 10;
+
+ // dml name
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ // generate program args
+ List<String> programArgsList = new ArrayList<>(Arrays.asList(
+ "-stats",
+ "-nvargs",
+ "examples_per_worker=" + _examplesPerWorker,
+ "num_features=" + numFeatures,
+ "num_labels=" + numLabels,
+ "epochs=" + _epochs,
+ "batch_size=" + _batch_size,
+ "eta=" + _eta,
+ "utype=" + _utype,
+ "freq=" + _freq,
+ "network_type=" + _networkType,
+ "channels=" + C,
+ "hin=" + Hin,
+ "win=" + Win
+ ));
+
+ // for each worker
+ List<Integer> ports = new ArrayList<>();
+ List<Thread> threads = new ArrayList<>();
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ // write row partitioned features to disk
+ writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
+ new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures));
+ // write row partitioned labels to disk
+ writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
+ new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker * numLabels));
+
+ // start worker
+ ports.add(getRandomAvailablePort());
+ threads.add(startLocalFedWorkerThread(ports.get(i)));
+
+ // add worker to program args
+ programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("X" + i)));
+ programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(ports.get(i), input("y" + i)));
+ }
+
+ programArgs = programArgsList.toArray(new String[0]);
+ // ByteArrayOutputStream stdout =
+ runTest(null);
+ // System.out.print(stdout.toString());
+
+ // cleanup
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ TestUtils.shutdownThreads(threads.get(i));
+ }
+ }
+
+ /**
+ * Generates a feature matrix that has the same format as the MNIST dataset,
+ * but is completely random and normalized
+ *
+ * @param numExamples Number of examples to generate
+ * @param C Channels in the input data
+ * @param Hin Height in Pixels of the input data
+ * @param Win Width in Pixels of the input data
+ * @return a dummy MNIST feature matrix
+ */
+ private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+ // Seed -1 takes the time in milliseconds as a seed
+ // Sparsity 1 means no sparsity
+ return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+ }
+
+ /**
+ * Generates a label matrix that has the same shape as the one-hot encoded MNIST labels, but is completely random
+ * (NOTE: the rows hold random values between 0 and 1 and are not strictly one-hot encoded vectors)
+ *
+ * @param numExamples Number of examples to generate
+ * @param numLabels Number of labels to generate
+ * @return a dummy MNIST label matrix
+ */
+ private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+ // Seed -1 takes the time in milliseconds as a seed
+ // Sparsity 1 means no sparsity
+ return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+ }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
new file mode 100644
index 0000000..55d05dc
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -0,0 +1,474 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a convolutional neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture, single-threaded, in the conventional way.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - y: Target matrix, of shape (N, K)
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ * - y_val: Target validation matrix, of shape (N, K)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int C, int Hin, int Win, int epochs, int batch_size, double learning_rate)
+ return (list[unknown] model_trained) {
+
+ N = nrow(X)
+ K = ncol(y)
+
+ # Create network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ learning_rate = learning_rate # learning rate
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+ # Regularization
+ lambda = 5e-04
+
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+ # Calculate iterations
+ iters = ceil(N / batch_size)
+ print_interval = floor(iters / 25)
+
+ print("[+] Starting optimization")
+ print("[+] Learning rate: " + learning_rate)
+ print("[+] Batch size: " + batch_size)
+ print("[+] Iterations per epoch: " + iters + "\n")
+
+ for (e in 1:epochs) {
+ print("[+] Starting epoch: " + e)
+ print("|")
+ for(i in 1:iters) {
+ # Create the model list
+ model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+ y_batch = y[beg:end,]
+
+ gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+ model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+ W1 = as.matrix(model_updated[1])
+ W2 = as.matrix(model_updated[2])
+ W3 = as.matrix(model_updated[3])
+ W4 = as.matrix(model_updated[4])
+ b1 = as.matrix(model_updated[5])
+ b2 = as.matrix(model_updated[6])
+ b3 = as.matrix(model_updated[7])
+ b4 = as.matrix(model_updated[8])
+ vW1 = as.matrix(model_updated[9])
+ vW2 = as.matrix(model_updated[10])
+ vW3 = as.matrix(model_updated[11])
+ vW4 = as.matrix(model_updated[12])
+ vb1 = as.matrix(model_updated[13])
+ vb2 = as.matrix(model_updated[14])
+ vb3 = as.matrix(model_updated[15])
+ vb4 = as.matrix(model_updated[16])
+ if((i %% print_interval) == 0) {
+ print("█")
+ }
+ }
+ print("|")
+ }
+
+ model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture using a parameter server with specified properties.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - y: Target matrix, of shape (N, K)
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ * - y_val: Target validation matrix, of shape (N, K)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ * - workers: Number of workers to create
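+ * - utype: update type of the parameter server (e.g. BSP or ASP); freq: update frequency (e.g. BATCH or EPOCH)
+ * - scheme: data partitioning scheme (e.g. DISJOINT_CONTIGUOUS)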
+ * - mode: local or distributed
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int C, int Hin, int Win, int epochs, int workers,
+ string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+ return (list[unknown] model_trained) {
+
+ N = nrow(X)
+ K = ncol(y)
+
+ # Create network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf) # inputs: (N, F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3) # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ learning_rate = learning_rate # learning rate
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+ # Regularization
+ lambda = 5e-04
+ # Create the model list
+ model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+ # Create the hyper parameter list
+ params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+ # Use paramserv function
+ model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - batch_size: Batch size
+ * - model: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model)
+ return (matrix[double] probs) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ N = nrow(X)
+
+ # Network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = nrow(W1) # num conv filters in conv1
+ F2 = nrow(W2) # num conv filters in conv2
+ N3 = ncol(W3) # num nodes in affine3
+ K = ncol(W4) # num nodes in affine4, equal to number of target dimensions (num classes)
+
+ # Compute predictions over mini-batches
+ probs = matrix(0, rows=N, cols=K)
+ iters = ceil(N / batch_size)
+ parfor(i in 1:iters, check=0) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+ pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_batch = softmax::forward(outa4)
+
+ # Store predictions
+ probs[beg:end,] = probs_batch
+ }
+}
+
+/*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K)
+ * - y: Target matrix, of shape (N, K)
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1)
+ * - accuracy: Scalar accuracy, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+ return (double loss, double accuracy) {
+
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+ accuracy = mean(correct_pred)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute loss & accuracy for training data
+ loss = cross_entropy_loss::forward(probs, labels)
+ accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+ print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+ # Compute data backward pass
+ ## loss
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+ Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and always return a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] model_result) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
+
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, learning_rate, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, learning_rate, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, learning_rate, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, learning_rate, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, learning_rate, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, learning_rate, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, learning_rate, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
+
+ model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
new file mode 100644
index 0000000..16c72c4
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+
+# federated inputs: two worker addresses, each covering one block of $examples_per_worker rows
+features = federated(addresses=list($X0, $X1),
+ ranges=list(list(0, 0), list($examples_per_worker, $num_features),
+ list($examples_per_worker, 0), list(2 * $examples_per_worker, $num_features)))
+
+labels = federated(addresses=list($y0, $y1),
+ ranges=list(list(0, 0), list($examples_per_worker, $num_labels),
+ list($examples_per_worker, 0), list(2 * $examples_per_worker, $num_labels)))
+
+network_type = $network_type
+epochs = $epochs
+batch_size = $batch_size
+learning_rate = $eta
+utype = $utype
+freq = $freq
+
+# config for the cnn (only passed to CNN::train_paramserv)
+channels = $channels
+hin = $hin
+win = $win
+
+# currently ignored parameters
+workers = 1
+scheme = "DISJOINT_CONTIGUOUS"
+paramserv_mode = "LOCAL"
+
+if(network_type != "TwoNN") {
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+else {
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
new file mode 100644
index 0000000..3bcfe84
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -0,0 +1,299 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a simple feed forward neural network
+ * on different execution schemes and with different inputs, for example a federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers single threaded the conventional way.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix of shape (N, D)
+ * - y: Target matrix of shape (N, K)
+ * - X_val: Input validation data matrix of shape (N_val, D); not used by this function
+ * - y_val: Target validation matrix of shape (N_val, K); not used by this function
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ * - model_trained: List containing, in this order,
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b1: 1st layer biases vector, of shape (1, 200)
+ * - b2: 2nd layer biases vector, of shape (1, 200)
+ * - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int epochs, int batch_size, double learning_rate)
+ return (list[unknown] model_trained) {
+
+ N = nrow(X) # num examples
+ D = ncol(X) # num features
+ K = ncol(y) # num classes
+
+ # Create the network:
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ [W1, b1] = affine::init(D, 200)
+ [W2, b2] = affine::init(200, 200)
+ [W3, b3] = affine::init(200, K)
+ W3 = W3 / sqrt(2) # different initialization, since being fed into softmax, instead of relu
+
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=learning_rate)
+ # Calculate iterations
+ iters = ceil(N / batch_size)
+ print_interval = max(1, floor(iters / 25)) # at least 1, so the modulo below never uses a zero interval when iters < 25
+
+ print("[+] Starting optimization")
+ print("[+] Learning rate: " + learning_rate)
+ print("[+] Batch size: " + batch_size)
+ print("[+] Iterations per epoch: " + iters + "\n")
+
+ for (e in 1:epochs) {
+ print("[+] Starting epoch: " + e)
+ print("|")
+ for(i in 1:iters) {
+ # Create the model list
+ model_list = list(W1, W2, W3, b1, b2, b3)
+
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+ y_batch = y[beg:end,]
+
+ gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+ model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+ W1 = as.matrix(model_updated[1])
+ W2 = as.matrix(model_updated[2])
+ W3 = as.matrix(model_updated[3])
+ b1 = as.matrix(model_updated[4])
+ b2 = as.matrix(model_updated[5])
+ b3 = as.matrix(model_updated[6])
+
+ if((i %% print_interval) == 0) {
+ print("█")
+ }
+ }
+ print("|")
+ }
+
+ model_trained = list(W1, W2, W3, b1, b2, b3)
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using a parameter server with specified properties.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix of shape (N, D)
+ * - y: Target matrix of shape (N, K)
+ * - X_val: Input validation data matrix of shape (N_val, D)
+ * - y_val: Target validation matrix of shape (N_val, K)
+ * - epochs: Total number of full training loops over the full data set
+ * - workers: Number of workers to create (paramserv argument k)
+ * - utype: Update strategy of the parameter server (paramserv argument utype)
+ * - freq: Update frequency of the parameter server (paramserv argument freq)
+ * - batch_size: Batch size
+ * - scheme: Data partitioning scheme (paramserv argument scheme)
+ * - mode: Execution mode, local or distributed (paramserv argument mode)
+ * - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ * - model_trained: Model list returned by paramserv; the initial model passed in is laid out as
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b1: 1st layer biases vector, of shape (1, 200)
+ * - b2: 2nd layer biases vector, of shape (1, 200)
+ * - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int epochs, int workers,
+ string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+ return (list[unknown] model_trained) {
+
+ D = ncol(X) # num features
+ K = ncol(y) # num classes
+
+ # Create the network:
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ [W1, b1] = affine::init(D, 200)
+ [W2, b2] = affine::init(200, 200)
+ [W3, b3] = affine::init(200, K) # NOTE(review): train() also scales W3 by 1/sqrt(2); not done here - confirm intended
+
+ # Create the model list
+ model_list = list(W1, W2, W3, b1, b2, b3)
+ # Create the hyper parameter list
+ params = list(learning_rate=learning_rate)
+ # Use paramserv function
+ model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val, upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients", agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Runs the forward pass of the feed forward network and returns class probabilities.
+ *
+ * Inputs:
+ * - X: The input data matrix of shape (N, D)
+ * - model: List laid out as (W1, W2, W3, b1, b2, b3), with
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b1: 1st layer biases vector, of shape (1, 200)
+ * - b2: 2nd layer biases vector, of shape (1, 200)
+ * - b3: 3rd layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, list[unknown] model)
+ return (matrix[double] probs) {
+ W1 = as.matrix(model[1])
+ b1 = as.matrix(model[4])
+ W2 = as.matrix(model[2])
+ b2 = as.matrix(model[5])
+ W3 = as.matrix(model[3])
+ b3 = as.matrix(model[6])
+
+ # input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ hidden1 = relu::forward(affine::forward(X, W1, b1))
+ hidden2 = relu::forward(affine::forward(hidden1, W2, b2))
+ scores = affine::forward(hidden2, W3, b3)
+ probs = softmax::forward(scores)
+}
+
+/*
+ * Scores class probability predictions against one-hot encoded targets.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K).
+ * - y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ * - loss: Scalar loss returned by the cross entropy loss layer.
+ * - accuracy: Scalar accuracy, the mean of per-row index-of-maximum matches.
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+ return (double loss, double accuracy) {
+
+ # Accuracy: share of rows where the max column index of probs equals that of y
+ accuracy = mean(rowIndexMax(probs) == rowIndexMax(y))
+ # Loss: forward pass of the cross entropy loss layer
+ loss = cross_entropy_loss::forward(probs, y)
+}
+
+# Update function handed to paramserv (upd=...): keeps the argument names
+# 'model', 'hyperparams', 'features' (batch features), 'labels' (batch labels)
+# and returns the gradients as a list ordered (dW1, dW2, dW3, db1, db2, db3)
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+
+ W1 = as.matrix(model[1])
+ b1 = as.matrix(model[4])
+ W2 = as.matrix(model[2])
+ b2 = as.matrix(model[5])
+ W3 = as.matrix(model[3])
+ b3 = as.matrix(model[6])
+
+ # Forward pass, keeping every intermediate for the backward pass
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ z1 = affine::forward(features, W1, b1)
+ a1 = relu::forward(z1)
+ z2 = affine::forward(a1, W2, b2)
+ a2 = relu::forward(z2)
+ scores = affine::forward(a2, W3, b3)
+ probs = softmax::forward(scores)
+
+ # Report loss & accuracy on this training batch
+ batch_loss = cross_entropy_loss::forward(probs, labels)
+ batch_accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+ print("[+] Completed forward pass on batch: train loss: " + batch_loss + ", train accuracy: " + batch_accuracy)
+
+ # Backward pass, walking the layers in reverse order
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ dscores = softmax::backward(dprobs, scores)
+ [da2, dW3, db3] = affine::backward(dscores, a2, W3, b3)
+ dz2 = relu::backward(da2, z2)
+ [da1, dW2, db2] = affine::backward(dz2, a1, W2, b2)
+ dz1 = relu::backward(da1, z1)
+ [dX, dW1, db1] = affine::backward(dz1, features, W1, b1)
+
+ gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
+
+# Aggregation function handed to paramserv (agg=...): keeps the argument names
+# 'model', 'gradients', 'hyperparams' and always returns a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] model_result) {
+
+ lr = as.double(as.scalar(hyperparams["learning_rate"]))
+ W1 = as.matrix(model[1])
+ dW1 = as.matrix(gradients[1])
+ b1 = as.matrix(model[4])
+ db1 = as.matrix(gradients[4])
+ W2 = as.matrix(model[2])
+ dW2 = as.matrix(gradients[2])
+ b2 = as.matrix(model[5])
+ db2 = as.matrix(gradients[5])
+ W3 = as.matrix(model[3])
+ dW3 = as.matrix(gradients[3])
+ b3 = as.matrix(model[6])
+ db3 = as.matrix(gradients[6])
+
+ # Optimize with SGD; the six updates are independent of each other
+ W1 = sgd::update(W1, dW1, lr)
+ b1 = sgd::update(b1, db1, lr)
+ W2 = sgd::update(W2, dW2, lr)
+ b2 = sgd::update(b2, db2, lr)
+ W3 = sgd::update(W3, dW3, lr)
+ b3 = sgd::update(b3, db3, lr)
+
+ model_result = list(W1, W2, W3, b1, b2, b3)
+}
\ No newline at end of file