You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/08 18:52:50 UTC
[systemds] branch master updated: [SYSTEMDS-2600] Rework federated
runtime backend (framework, ops)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new a4f992e [SYSTEMDS-2600] Rework federated runtime backend (framework, ops)
a4f992e is described below
commit a4f992ed86d92cff95b160dab0c852b5434bed25
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Aug 8 20:51:25 2020 +0200
[SYSTEMDS-2600] Rework federated runtime backend (framework, ops)
This patch substantially reworks the existing federated runtime backend
and operations in order to simplify the joint development of all
remaining federated operations.
The new design has only four command types (read, put, get, exec_inst),
which allow reading federated matrices, putting and getting variables, and
executing arbitrary instructions over these variables. With this approach,
we can reuse the existing symbol table and CP/Spark instructions and
only need to handle their orchestration and global compensations.
Furthermore, the new design adds several primitives such as broadcast,
broadcastSliced, aggregations, and rbind/cbind, as well as more convenient
data structures.
Finally, this patch also includes minor reworks of the execution
context and of the reblock rewrite to account for specific characteristics
of federated execution.
---
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../hops/rewrite/RewriteBlockSizeAndReblock.java | 5 +-
.../controlprogram/caching/CacheableData.java | 10 +-
.../controlprogram/caching/FrameObject.java | 3 +-
.../controlprogram/caching/MatrixObject.java | 4 +-
.../controlprogram/context/ExecutionContext.java | 56 +++-
.../context/SparkExecutionContext.java | 6 +-
.../controlprogram/federated/FederatedData.java | 57 +---
.../controlprogram/federated/FederatedRange.java | 10 +
.../controlprogram/federated/FederatedRequest.java | 50 ++-
.../federated/FederatedResponse.java | 22 +-
.../controlprogram/federated/FederatedWorker.java | 24 +-
.../federated/FederatedWorkerHandler.java | 278 ++++++----------
.../federated/FederatedWorkerHandlerException.java | 4 +
.../controlprogram/federated/FederationMap.java | 153 +++++++++
.../controlprogram/federated/FederationUtils.java | 125 +++++++
.../controlprogram/federated/LibFederatedAgg.java | 103 ------
.../federated/LibFederatedAppend.java | 80 -----
.../cp/MatrixIndexingCPInstruction.java | 2 +-
.../instructions/cp/VariableCPInstruction.java | 7 +
.../fed/AggregateBinaryFEDInstruction.java | 359 ++-------------------
.../fed/AggregateUnaryFEDInstruction.java | 60 ++--
.../instructions/fed/AppendFEDInstruction.java | 72 +++--
.../fed/BinaryMatrixScalarFEDInstruction.java | 75 ++---
.../instructions/fed/InitFEDInstruction.java | 16 +-
.../matrix/operators/AggregateUnaryOperator.java | 10 +
.../apache/sysds/runtime/util/UtilFunctions.java | 26 --
.../org/apache/sysds/test/AutomatedTestBase.java | 2 +-
.../federated/FederatedConstructionTest.java | 4 +-
.../functions/federated/FederatedMultiplyTest.java | 6 +-
.../functions/federated/FederatedRCBindTest.java | 7 +-
.../test/functions/federated/FederatedSumTest.java | 4 +-
.../functions/federated/FederatedSumTest.dml | 1 +
.../FederatedMatrixAdditionScalar.dml | 1 +
34 files changed, 692 insertions(+), 952 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index b3f8de1..92027a5 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -37,7 +37,7 @@ public class Types
/**
* Execution type of individual operations.
*/
- public enum ExecType { CP, CP_FILE, SPARK, GPU, INVALID }
+ public enum ExecType { CP, CP_FILE, SPARK, GPU, FED, INVALID }
/**
* Data types (tensor, matrix, scalar, frame, object, unknown).
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
index 1da6aa5..cc24919 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteBlockSizeAndReblock.java
@@ -134,8 +134,9 @@ public class RewriteBlockSizeAndReblock extends HopRewriteRule
}
}
else if (dop.getOp() == OpOpData.FEDERATED) {
- // TODO maybe do something here?
- } else {
+ dop.setBlocksize(blocksize);
+ }
+ else {
throw new HopsException(hop.printErrorLocation() + "unexpected non-scalar Data HOP in reblock.\n");
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 602393e..590b3e5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -32,8 +32,7 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -170,8 +169,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
*/
protected PrivacyConstraint _privacyConstraint = null;
- protected Map<FederatedRange, FederatedData> _fedMapping = null;
-
+ protected FederationMap _fedMapping = null;
/** The name of HDFS file in which the data is backed up. */
protected String _hdfsFileName = null; // file name and path
@@ -357,7 +355,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
* Gets the mapping of indices ranges to federated objects.
* @return fedMapping mapping
*/
- public Map<FederatedRange, FederatedData> getFedMapping() {
+ public FederationMap getFedMapping() {
return _fedMapping;
}
@@ -365,7 +363,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
* Sets the mapping of indices ranges to federated objects.
* @param fedMapping mapping
*/
- public void setFedMapping(Map<FederatedRange, FederatedData> fedMapping) {
+ public void setFedMapping(FederationMap fedMapping) {
_fedMapping = fedMapping;
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index 19c33a9..ef6e790 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -47,7 +47,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Future;
-import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
public class FrameObject extends CacheableData<FrameBlock>
{
@@ -169,7 +168,7 @@ public class FrameObject extends CacheableData<FrameBlock>
FrameBlock result = new FrameBlock(_schema);
// provide long support?
result.ensureAllocatedColumns((int) _metaData.getDataCharacteristics().getRows());
- List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = requestFederatedData(_fedMapping);
+ List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = _fedMapping.requestFederatedData();
try {
for(Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
FederatedRange range = readResponse.getLeft();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 7509c02..6216ba5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.controlprogram.caching;
-import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
-
import java.io.IOException;
import java.lang.ref.SoftReference;
import java.util.List;
@@ -405,7 +403,7 @@ public class MatrixObject extends CacheableData<MatrixBlock>
long[] dims = getDataCharacteristics().getDims();
// TODO sparse optimization
MatrixBlock result = new MatrixBlock((int) dims[0], (int) dims[1], false);
- List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = requestFederatedData(_fedMapping);
+ List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = _fedMapping.requestFederatedData();
try {
for (Pair<FederatedRange, Future<FederatedResponse>> readResponse : readResponses) {
FederatedRange range = readResponse.getLeft();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 3354eb8..7be3bfd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -22,11 +22,15 @@ package org.apache.sysds.runtime.controlprogram.context;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -66,7 +70,8 @@ public class ExecutionContext {
//symbol table
protected LocalVariableMap _variables;
-
+ protected boolean _autoCreateVars;
+
//lineage map, cache, prepared dedup blocks
protected Lineage _lineage;
@@ -83,12 +88,14 @@ public class ExecutionContext {
protected ExecutionContext( boolean allocateVariableMap, boolean allocateLineage, Program prog ) {
//protected constructor to force use of ExecutionContextFactory
_variables = allocateVariableMap ? new LocalVariableMap() : null;
+ _autoCreateVars = false;
_lineage = allocateLineage ? new Lineage() : null;
_prog = prog;
}
public ExecutionContext(LocalVariableMap vars) {
_variables = vars;
+ _autoCreateVars = false;
_lineage = null;
_prog = null;
}
@@ -116,6 +123,14 @@ public class ExecutionContext {
public void setLineage(Lineage lineage) {
_lineage = lineage;
}
+
+ public boolean isAutoCreateVars() {
+ return _autoCreateVars;
+ }
+
+ public void setAutoCreateVars(boolean flag) {
+ _autoCreateVars = flag;
+ }
/**
* Get the i-th GPUContext
@@ -502,6 +517,8 @@ public class ExecutionContext {
}
public void setMatrixOutput(String varName, MatrixBlock outputData) {
+ if( isAutoCreateVars() && !containsVariable(varName) )
+ setVariable(varName, createMatrixObject(outputData));
MatrixObject mo = getMatrixObject(varName);
mo.acquireModify(outputData);
mo.release();
@@ -509,6 +526,8 @@ public class ExecutionContext {
}
public void setMatrixOutput(String varName, MatrixBlock outputData, UpdateType flag) {
+ if( isAutoCreateVars() && !containsVariable(varName) )
+ setVariable(varName, createMatrixObject(outputData));
if( flag.isInPlace() ) {
//modify metadata to carry update status
MatrixObject mo = getMatrixObject(varName);
@@ -517,10 +536,6 @@ public class ExecutionContext {
setMatrixOutput(varName, outputData);
}
- public void setMatrixOutput(String varName, MatrixBlock outputData, UpdateType flag, String opcode) {
- setMatrixOutput(varName, outputData, flag);
- }
-
public void setTensorOutput(String varName, TensorBlock outputData) {
TensorObject to = getTensorObject(varName);
to.acquireModify(outputData);
@@ -529,11 +544,42 @@ public class ExecutionContext {
}
public void setFrameOutput(String varName, FrameBlock outputData) {
+ if( isAutoCreateVars() && !containsVariable(varName) )
+ setVariable(varName, createFrameObject(outputData));
FrameObject fo = getFrameObject(varName);
fo.acquireModify(outputData);
fo.release();
setVariable(varName, fo);
}
+
+ public static CacheableData<?> createCacheableData(CacheBlock cb) {
+ if( cb instanceof MatrixBlock )
+ return createMatrixObject((MatrixBlock) cb);
+ else if( cb instanceof FrameBlock )
+ return createFrameObject((FrameBlock) cb);
+ return null;
+ }
+
+ private static CacheableData<?> createMatrixObject(MatrixBlock mb) {
+ MatrixObject ret = new MatrixObject(Types.ValueType.FP64,
+ OptimizerUtils.getUniqueTempFileName());
+ ret.acquireModify(mb);
+ ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
+ mb.getNumRows(), mb.getNumColumns()), FileFormat.BINARY));
+ ret.getMetaData().getDataCharacteristics()
+ .setBlocksize(ConfigurationManager.getBlocksize());
+ ret.release();
+ return ret;
+ }
+
+ private static CacheableData<?> createFrameObject(FrameBlock fb) {
+ FrameObject ret = new FrameObject(OptimizerUtils.getUniqueTempFileName());
+ ret.acquireModify(fb);
+ ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
+ fb.getNumRows(), fb.getNumColumns()), FileFormat.BINARY));
+ ret.release();
+ return ret;
+ }
public List<MatrixBlock> getMatrixInputs(CPOperand[] inputs) {
return getMatrixInputs(inputs, false);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 11a4e93..510113c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -381,14 +381,14 @@ public class SparkExecutionContext extends ExecutionContext
rdd = mo.getRDDHandle().getRDD();
}
//CASE 2: dirty in memory data or cached result of rdd operations
- else if( mo.isDirty() || mo.isCached(false) )
+ else if( mo.isDirty() || mo.isCached(false) || mo.isFederated() )
{
//get in-memory matrix block and parallelize it
//w/ guarded parallelize (fallback to export, rdd from file if too large)
DataCharacteristics dc = mo.getDataCharacteristics();
boolean fromFile = false;
- if( !OptimizerUtils.checkSparkCollectMemoryBudget(dc, 0) || !_parRDDs.reserve(
- OptimizerUtils.estimatePartitionedSizeExactSparsity(dc))) {
+ if( !mo.isFederated() && (!OptimizerUtils.checkSparkCollectMemoryBudget(dc, 0)
+ || !_parRDDs.reserve(OptimizerUtils.estimatePartitionedSizeExactSparsity(dc)))) {
if( mo.isDirty() || !mo.isHDFSFileExists() ) //write if necessary
mo.exportData();
rdd = sc.hadoopFile( mo.getFileName(), inputInfo.inputFormatClass, inputInfo.keyClass, inputInfo.valueClass);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 32a3457..1d5f5df 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -35,6 +35,7 @@ import io.netty.util.concurrent.Promise;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import java.net.InetSocketAddress;
import java.util.concurrent.Future;
@@ -83,51 +84,15 @@ public class FederatedData {
return _varID != -1;
}
- public synchronized Future<FederatedResponse> initFederatedData() {
+ public synchronized Future<FederatedResponse> initFederatedData(long id) {
if(isInitialized())
throw new DMLRuntimeException("Tried to init already initialized data");
- FederatedRequest.FedMethod fedMethod;
- switch(_dataType) {
- case MATRIX:
- fedMethod = FederatedRequest.FedMethod.READ_MATRIX;
- break;
- case FRAME:
- fedMethod = FederatedRequest.FedMethod.READ_FRAME;
- break;
- default:
- throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported.");
- }
- FederatedRequest request = new FederatedRequest(fedMethod);
+ if(!_dataType.isMatrix() && !_dataType.isFrame())
+ throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported.");
+ _varID = id;
+ FederatedRequest request = new FederatedRequest(RequestType.READ_VAR, id);
request.appendParam(_filepath);
- return executeFederatedOperation(request);
- }
-
- /**
- * Executes an federated operation on a federated worker and default variable.
- *
- * @param request the requested operation
- * @param withVarID true if we should add the default varID (initialized) or false if we should not
- * @return the response
- */
- public Future<FederatedResponse> executeFederatedOperation(FederatedRequest request, boolean withVarID) {
- if (withVarID) {
- if( !isInitialized() )
- throw new DMLRuntimeException("Tried to execute federated operation on data non initialized federated data.");
- return executeFederatedOperation(request, _varID);
- }
- return executeFederatedOperation(request);
- }
-
- /**
- * Executes an federated operation on a federated worker.
- *
- * @param request the requested operation
- * @param varID variable ID
- * @return the response
- */
- public Future<FederatedResponse> executeFederatedOperation(FederatedRequest request, long varID) {
- request = request.deepClone();
- request.appendParam(varID);
+ request.appendParam(_dataType.name());
return executeFederatedOperation(request);
}
@@ -137,7 +102,7 @@ public class FederatedData {
* @param request the requested operation
* @return the response
*/
- public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest request) {
+ public synchronized Future<FederatedResponse> executeFederatedOperation(FederatedRequest... request) {
// Careful with the number of threads. Each thread opens connections to multiple files making resulting in
// java.io.IOException: Too many open files
EventLoopGroup workerGroup = new NioEventLoopGroup(_nrThreads);
@@ -148,9 +113,9 @@ public class FederatedData {
@Override
public void initChannel(SocketChannel ch) {
ch.pipeline().addLast("ObjectDecoder",
- new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())))
- .addLast("FederatedOperationHandler", handler)
- .addLast("ObjectEncoder", new ObjectEncoder());
+ new ObjectDecoder(Integer.MAX_VALUE, ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())))
+ .addLast("FederatedOperationHandler", handler)
+ .addLast("ObjectEncoder", new ObjectEncoder());
}
});
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index d8e20a5..b4f69ad 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -92,4 +92,14 @@ public class FederatedRange implements Comparable<FederatedRange> {
public String toString() {
return Arrays.toString(_beginDims) + " - " + Arrays.toString(_endDims);
}
+
+ public FederatedRange shift(long rshift, long cshift) {
+ //row shift
+ _beginDims[0] += rshift;
+ _endDims[0] += rshift;
+ //column shift
+ _beginDims[1] += cshift;
+ _endDims[1] += cshift;
+ return this;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 771f828..f2d53e4 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -29,36 +29,47 @@ import org.apache.sysds.api.DMLScript;
public class FederatedRequest implements Serializable {
private static final long serialVersionUID = 5946781306963870394L;
- public enum FedMethod {
- READ_MATRIX, READ_FRAME, MATVECMULT, TRANSFER, AGGREGATE, SCALAR
+ // commands sent to and executed by federated workers
+ public enum RequestType {
+ READ_VAR, // create variable for local data, read on first access
+ PUT_VAR, // receive data from main and store to local variable
+ GET_VAR, // return local variable to main
+ EXEC_INST // execute arbitrary instruction over local variables
}
- private FedMethod _method;
+ private RequestType _method;
+ private long _id;
private List<Object> _data;
private boolean _checkPrivacy;
- public FederatedRequest(FedMethod method, List<Object> data) {
- _method = method;
- _data = data;
- setCheckPrivacy();
+
+ public FederatedRequest(RequestType method) {
+ this(method, FederationUtils.getNextFedDataID(), new ArrayList<>());
}
- public FederatedRequest(FedMethod method, Object ... datas) {
- _method = method;
- _data = Arrays.asList(datas);
- setCheckPrivacy();
+ public FederatedRequest(RequestType method, long id) {
+ this(method, id, new ArrayList<>());
}
- public FederatedRequest(FedMethod method) {
+ public FederatedRequest(RequestType method, long id, Object ... data) {
+ this(method, id, Arrays.asList(data));
+ }
+
+ public FederatedRequest(RequestType method, long id, List<Object> data) {
_method = method;
- _data = new ArrayList<>();
+ _id = id;
+ _data = data;
setCheckPrivacy();
}
- public FedMethod getMethod() {
+ public RequestType getType() {
return _method;
}
+ public long getID() {
+ return _id;
+ }
+
public Object getParam(int i) {
return _data.get(i);
}
@@ -78,7 +89,7 @@ public class FederatedRequest implements Serializable {
}
public FederatedRequest deepClone() {
- return new FederatedRequest(_method, new ArrayList<>(_data));
+ return new FederatedRequest(_method, _id, new ArrayList<>(_data));
}
public void setCheckPrivacy(boolean checkPrivacy){
@@ -92,4 +103,13 @@ public class FederatedRequest implements Serializable {
public boolean checkPrivacy(){
return _checkPrivacy;
}
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder("FederatedRequest[");
+ sb.append(_method); sb.append(";");
+ sb.append(_id); sb.append(";");
+ sb.append(Arrays.toString(_data.toArray())); sb.append("]");
+ return sb.toString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index 3335aae..ea03bd4 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -32,36 +32,36 @@ import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
public class FederatedResponse implements Serializable {
private static final long serialVersionUID = 3142180026498695091L;
- public enum Type {
+ public enum ResponseType {
SUCCESS,
SUCCESS_EMPTY,
ERROR,
}
- private FederatedResponse.Type _status;
+ private ResponseType _status;
private Object[] _data;
private Map<PrivacyLevel,LongAdder> checkedConstraints;
- public FederatedResponse(FederatedResponse.Type status) {
+ public FederatedResponse(ResponseType status) {
this(status, null);
}
- public FederatedResponse(FederatedResponse.Type status, Object[] data) {
+ public FederatedResponse(ResponseType status, Object[] data) {
_status = status;
_data = data;
- if( _status == FederatedResponse.Type.SUCCESS && data == null )
- _status = FederatedResponse.Type.SUCCESS_EMPTY;
+ if( _status == ResponseType.SUCCESS && data == null )
+ _status = ResponseType.SUCCESS_EMPTY;
}
- public FederatedResponse(FederatedResponse.Type status, Object data) {
+ public FederatedResponse(FederatedResponse.ResponseType status, Object data) {
_status = status;
_data = new Object[] {data};
- if(_status == FederatedResponse.Type.SUCCESS && data == null)
- _status = FederatedResponse.Type.SUCCESS_EMPTY;
+ if(_status == ResponseType.SUCCESS && data == null)
+ _status = ResponseType.SUCCESS_EMPTY;
}
public boolean isSuccessful() {
- return _status != FederatedResponse.Type.ERROR;
+ return _status != ResponseType.ERROR;
}
public String getErrorMessage() {
@@ -103,7 +103,7 @@ public class FederatedResponse implements Serializable {
if ( checkedConstraints != null && !checkedConstraints.isEmpty() ){
this.checkedConstraints = new EnumMap<>(PrivacyLevel.class);
this.checkedConstraints.putAll(checkedConstraints);
- }
+ }
}
public void updateCheckedConstraintsLog(){
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index afed54b..1eca3a9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -32,29 +32,29 @@ import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import org.apache.log4j.Logger;
import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
-import org.apache.sysds.runtime.instructions.cp.Data;
-
-import java.util.HashMap;
-import java.util.Map;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
public class FederatedWorker {
protected static Logger log = Logger.getLogger(FederatedWorker.class);
private int _port;
- private int _nrThreads = Integer.parseInt(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
- private IDSequence _seq = new IDSequence();
- private Map<Long, Data> _vars = new HashMap<>();
-
+ private final ExecutionContext _ec;
+ private final BasicProgramBlock _pb;
+
public FederatedWorker(int port) {
+ _ec = ExecutionContextFactory.createContext();
+ _ec.setAutoCreateVars(true); //w/o createvar inst
+ _pb = new BasicProgramBlock(null);
_port = (port == -1) ?
Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT) : port;
}
public void run() {
log.info("Setting up Federated Worker");
- EventLoopGroup bossGroup = new NioEventLoopGroup(_nrThreads);
- EventLoopGroup workerGroup = new NioEventLoopGroup(_nrThreads);
+ EventLoopGroup bossGroup = new NioEventLoopGroup(1);
+ EventLoopGroup workerGroup = new NioEventLoopGroup(1);
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
.childHandler(new ChannelInitializer<SocketChannel>() {
@@ -65,7 +65,7 @@ public class FederatedWorker {
new ObjectDecoder(Integer.MAX_VALUE,
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())))
.addLast("ObjectEncoder", new ObjectEncoder())
- .addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_seq, _vars));
+ .addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_ec, _pb));
}
}).option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE, true);
try {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 1e2e6ea..6f6760f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -27,25 +27,24 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
-import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
+import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
-import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
@@ -57,39 +56,39 @@ import org.apache.wink.json4j.JSONObject;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.Arrays;
-import java.util.Map;
public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
- private final IDSequence _seq;
- private Map<Long, Data> _vars;
-
- public FederatedWorkerHandler(IDSequence seq, Map<Long, Data> _vars2) {
- _seq = seq;
- _vars = _vars2;
+ private final ExecutionContext _ec;
+ private final BasicProgramBlock _pb;
+
+ public FederatedWorkerHandler(ExecutionContext ec, BasicProgramBlock pb) {
+ _ec = ec;
+ _pb = pb;
}
+ /**
+ * Executes a batch of federated requests in order and sends back a single response:
+ * the response of the last request if all succeeded, otherwise the response of the
+ * first failed request (remaining requests of the batch are not executed).
+ */
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
log.debug("Received: " + msg.getClass().getSimpleName());
- FederatedRequest request;
- if (msg instanceof FederatedRequest)
- request = (FederatedRequest) msg;
- else
- throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of `FederatedRequest`.");
- FederatedRequest.FedMethod method = request.getMethod();
- log.debug("Received command: " + method.name());
- PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
- PrivacyMonitor.clearCheckedConstraints();
-
- synchronized (_seq) {
- FederatedResponse response = constructResponse(request);
+ if (!(msg instanceof FederatedRequest[]))
+ throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of 'FederatedRequest[]'.");
+ FederatedRequest[] requests = (FederatedRequest[]) msg;
+ FederatedResponse response = null; //last response, or first failed response
+
+ //an empty batch would otherwise send a null response
+ if( requests.length == 0 )
+ response = new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("FederatedWorkerHandler: Received empty batch of requests."));
+
+ for( int i=0; i<requests.length; i++ ) {
+ FederatedRequest request = requests[i];
+ if( log.isDebugEnabled() )
+ log.debug("Executing command "+(i+1)+"/"+requests.length + ": " + request.getType().name());
+ PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
+ PrivacyMonitor.clearCheckedConstraints();
+
+ response = executeCommand(request);
conditionalAddCheckedConstraints(request, response);
- if (!response.isSuccessful())
- log.error("Method " + method + " failed: " + response.getErrorMessage());
- ctx.writeAndFlush(response).addListener(new CloseListener());
+ if (!response.isSuccessful()) {
+ log.error("Command " + request.getType() + " failed: " + response.getErrorMessage());
+ //stop at the first failure: later requests may depend on earlier ones, and
+ //a later success must not mask this error in the returned response
+ break;
+ }
}
+ ctx.writeAndFlush(response).addListener(new CloseListener());
}
private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response){
@@ -97,44 +96,41 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints());
}
- private FederatedResponse constructResponse(FederatedRequest request) {
- FederatedRequest.FedMethod method = request.getMethod();
+ /**
+ * Dispatches a single request by its type (READ_VAR, PUT_VAR, GET_VAR, EXEC_INST).
+ * Never throws: any failure is converted into an ERROR response.
+ *
+ * @param request federated request to execute
+ * @return response of the executed command, or an ERROR response
+ */
+ private FederatedResponse executeCommand(FederatedRequest request) {
+ RequestType method = request.getType();
try {
switch (method) {
- case READ_MATRIX:
- return readData(request, Types.DataType.MATRIX);
- case READ_FRAME:
- return readData(request, Types.DataType.FRAME);
- case MATVECMULT:
- return executeMatVecMult(request);
- case TRANSFER:
- return getVariableData(request);
- case AGGREGATE:
- return executeAggregation(request);
- case SCALAR:
- return executeScalarOperation(request);
+ case READ_VAR:
+ return readData(request); //matrix/frame
+ case PUT_VAR:
+ return putVariable(request);
+ case GET_VAR:
+ return getVariable(request);
+ case EXEC_INST:
+ return execInstruction(request);
default:
String message = String.format("Method %s is not supported.", method);
- return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException(message));
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException(message));
}
}
- catch (DMLPrivacyException | FederatedWorkerHandlerException exception) {
- return new FederatedResponse(FederatedResponse.Type.ERROR, exception);
+ //privacy and handler exceptions are returned directly as error payload
+ catch (DMLPrivacyException | FederatedWorkerHandlerException ex) {
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR, ex);
}
- catch (Exception exception) {
- return new FederatedResponse(FederatedResponse.Type.ERROR,
+ //any other exception is wrapped, keeping the original exception as cause
+ catch (Exception ex) {
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
new FederatedWorkerHandlerException("Exception of type "
- + exception.getClass() + " thrown when processing request"));
+ + ex.getClass() + " thrown when processing request", ex));
}
}
- private FederatedResponse readData(FederatedRequest request, Types.DataType dataType) {
- checkNumParams(request.getNumParams(), 1);
+ /**
+ * Reads a federated matrix or frame at this worker.
+ * Expects the file name as parameter 0 and the data type name (parsed via
+ * DataType.valueOf) as parameter 1; the read object is stored under the request ID.
+ */
+ private FederatedResponse readData(FederatedRequest request) {
+ checkNumParams(request.getNumParams(), 2);
String filename = (String) request.getParam(0);
- return readData(filename, dataType);
+ DataType dt = DataType.valueOf((String)request.getParam(1));
+ return readData(filename, dt, request.getID());
}
- private FederatedResponse readData(String filename, Types.DataType dataType) {
+ private FederatedResponse readData(String filename, Types.DataType dataType, long id) {
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
CacheableData<?> cd;
@@ -147,7 +143,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
break;
default:
// should NEVER happen (if we keep request codes in sync with actual behaviour)
- return new FederatedResponse(FederatedResponse.Type.ERROR,
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR,
new FederatedWorkerHandlerException("Could not recognize datatype"));
}
@@ -160,7 +156,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
if (mtd == null)
- return new FederatedResponse(FederatedResponse.Type.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
+ return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
mc.setRows(mtd.getLong(DataExpression.READROWPARAM));
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
cd = PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
@@ -173,140 +169,80 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
}
cd.setMetaData(new MetaDataFormat(mc, fmt));
cd.acquireRead();
- cd.refreshMetaData();
+ cd.refreshMetaData(); //in pinned state
cd.release();
-
- long id = _seq.getNextID();
- _vars.put(id, cd);
+
+ //TODO spawn async load of data, otherwise on first access
+ _ec.setVariable(String.valueOf(id), cd);
+
if (dataType == Types.DataType.FRAME) {
FrameObject frameObject = (FrameObject) cd;
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, new Object[] {id, frameObject.getSchema()});
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {id, frameObject.getSchema()});
}
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, id);
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, id);
}
-
- private FederatedResponse executeMatVecMult(FederatedRequest request) {
- checkNumParams(request.getNumParams(), 3);
- MatrixBlock vector = (MatrixBlock) request.getParam(0);
- boolean isMatVecMult = (Boolean) request.getParam(1);
- long varID = (Long) request.getParam(2);
-
- return executeMatVecMult(varID, vector, isMatVecMult);
- }
-
- private FederatedResponse executeMatVecMult(long varID, MatrixBlock vector, boolean isMatVecMult) {
- MatrixObject matTo = (MatrixObject) _vars.get(varID);
- matTo = PrivacyMonitor.handlePrivacy(matTo);
- MatrixBlock matBlock1 = matTo.acquireReadAndRelease();
- // TODO other datatypes
- AggregateBinaryOperator ab_op = InstructionUtils
- .getMatMultOperator(OptimizerUtils.getConstrainedNumThreads(0));
- MatrixBlock result = isMatVecMult ?
- matBlock1.aggregateBinaryOperations(matBlock1, vector, new MatrixBlock(), ab_op) :
- vector.aggregateBinaryOperations(vector, matBlock1, new MatrixBlock(), ab_op);
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, result);
- }
-
- private FederatedResponse getVariableData(FederatedRequest request) {
+
+ /**
+ * Stores the transferred data (parameter 0) at this worker under the request ID.
+ * Cache blocks are wrapped into cacheable data, scalars are stored as-is.
+ *
+ * @param request PUT_VAR request with exactly one parameter
+ * @return SUCCESS_EMPTY, or ERROR if the variable already exists or the payload type is unsupported
+ */
+ private FederatedResponse putVariable(FederatedRequest request) {
checkNumParams(request.getNumParams(), 1);
- long varID = (Long) request.getParam(0);
- return getVariableData(varID);
+ String varname = String.valueOf(request.getID());
+ if( _ec.containsVariable(varname) ) {
+ return new FederatedResponse(ResponseType.ERROR,
+ "Variable "+request.getID()+" already existing.");
+ }
+
+ //wrap transferred cache block into cacheable data, keep scalars as-is
+ Object param = request.getParam(0);
+ Data data;
+ if( param instanceof CacheBlock )
+ data = ExecutionContext.createCacheableData((CacheBlock) param);
+ else if( param instanceof ScalarObject )
+ data = (ScalarObject) param;
+ else //reject unsupported payloads instead of binding the variable to null
+ return new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Unsupported put datatype "
+ + (param == null ? "null" : param.getClass().getSimpleName())));
+
+ //set variable and construct empty response
+ _ec.setVariable(varname, data);
+ return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
-
+
+ /**
+ * Returns the data of the variable with the request ID, or an ERROR response
+ * if no such variable exists at this worker.
+ */
+ private FederatedResponse getVariable(FederatedRequest request) {
+ checkNumParams(request.getNumParams(), 0);
+ if( !_ec.containsVariable(String.valueOf(request.getID())) ) {
+ return new FederatedResponse(ResponseType.ERROR,
+ "Variable "+request.getID()+" does not exist at federated worker.");
+ }
+ //get variable and construct response
+ return getVariableData(request.getID());
+ }
+
+ /**
+ * Constructs a response with the data of the given variable (after
+ * PrivacyMonitor.handlePrivacy): the cache block for tensors, matrices, and frames,
+ * the list data for lists, and the scalar object itself for scalars.
+ */
private FederatedResponse getVariableData(long varID) {
- Data dataObject = _vars.get(varID);
+ Data dataObject = _ec.getVariable(String.valueOf(varID));
dataObject = PrivacyMonitor.handlePrivacy(dataObject);
switch (dataObject.getDataType()) {
case TENSOR:
- return new FederatedResponse(FederatedResponse.Type.SUCCESS,
- ((TensorObject) dataObject).acquireReadAndRelease());
case MATRIX:
- return new FederatedResponse(FederatedResponse.Type.SUCCESS,
- ((MatrixObject) dataObject).acquireReadAndRelease());
case FRAME:
- return new FederatedResponse(FederatedResponse.Type.SUCCESS,
- ((FrameObject) dataObject).acquireReadAndRelease());
+ //TENSOR, MATRIX, and FRAME fall through to a common CacheableData read
+ return new FederatedResponse(ResponseType.SUCCESS,
+ ((CacheableData<?>) dataObject).acquireReadAndRelease());
case LIST:
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, ((ListObject) dataObject).getData());
- // TODO rest of the possible datatypes
+ return new FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
+ case SCALAR:
+ return new FederatedResponse(ResponseType.SUCCESS, dataObject);
default:
- return new FederatedResponse(FederatedResponse.Type.ERROR,
- new FederatedWorkerHandlerException("Not possible to send datatype " + dataObject.getDataType().name()));
+ return new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Unsupported return datatype " + dataObject.getDataType().name()));
}
}
-
- private FederatedResponse executeAggregation(FederatedRequest request) {
- checkNumParams(request.getNumParams(), 2);
- AggregateUnaryOperator operator = (AggregateUnaryOperator) request.getParam(0);
- long varID = (Long) request.getParam(1);
- return executeAggregation(varID, operator);
- }
-
- private FederatedResponse executeAggregation(long varID, AggregateUnaryOperator operator) {
- Data dataObject = _vars.get(varID);
- if (dataObject.getDataType() != Types.DataType.MATRIX) {
- return new FederatedResponse(FederatedResponse.Type.ERROR,
- new FederatedWorkerHandlerException("Aggregation only supported for matrices, not for "
- + dataObject.getDataType().name()));
+
+ /**
+ * Executes a single instruction (serialized as string in parameter 0) over the
+ * variables of this worker, using the reusable program block as container.
+ *
+ * @param request EXEC_INST request with exactly one parameter
+ * @return SUCCESS_EMPTY on success, or ERROR wrapping the causing exception
+ */
+ private FederatedResponse execInstruction(FederatedRequest request) {
+ checkNumParams(request.getNumParams(), 1);
+ _pb.getInstructions().clear();
+ _pb.getInstructions().add(InstructionParser
+ .parseSingleInstruction((String)request.getParam(0)));
+ try {
+ _pb.execute(_ec); //execute single instruction
}
- MatrixObject matrixObject = (MatrixObject) dataObject;
- matrixObject = PrivacyMonitor.handlePrivacy(matrixObject);
- MatrixBlock matrixBlock = matrixObject.acquireRead();
- // create matrix for calculation with correction
- MatrixCharacteristics mc = new MatrixCharacteristics();
- // find out the characteristics after aggregation
- operator.indexFn.computeDimension(matrixObject.getDataCharacteristics(), mc);
- // make outBlock right size
- int outNumRows = (int) mc.getRows();
- int outNumCols = (int) mc.getCols();
- if (operator.aggOp.existsCorrection()) {
- // add rows for correction
- int numMissing = operator.aggOp.correction.getNumRemovedRowsColumns();
- if (operator.aggOp.correction.isRows())
- outNumRows += numMissing;
- else
- outNumCols += numMissing;
+ catch(Exception ex) {
+ //log through the handler's logger (not stderr) and keep the exception as cause;
+ //ex.getMessage() alone loses the stack trace and may be null
+ log.error("Failed to execute federated instruction.", ex);
+ return new FederatedResponse(ResponseType.ERROR,
+ new FederatedWorkerHandlerException("Exception of type "
+ + ex.getClass() + " thrown when executing instruction", ex));
}
- MatrixBlock ret = new MatrixBlock(outNumRows, outNumCols, operator.aggOp.initialValue);
- LibMatrixAgg.aggregateUnaryMatrix(matrixBlock, ret, operator);
- // result block without correction
- ret.dropLastRowsOrColumns(operator.aggOp.correction);
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, ret);
- }
-
- private FederatedResponse executeScalarOperation(FederatedRequest request) {
- checkNumParams(request.getNumParams(), 2);
- ScalarOperator operator = (ScalarOperator) request.getParam(0);
- long varID = (Long) request.getParam(1);
- return executeScalarOperation(varID, operator);
- }
-
- private FederatedResponse executeScalarOperation(long varID, ScalarOperator operator) {
- Data dataObject = _vars.get(varID);
- dataObject = PrivacyMonitor.handlePrivacy(dataObject);
- if (dataObject.getDataType() != Types.DataType.MATRIX) {
- return new FederatedResponse(FederatedResponse.Type.ERROR,
- new FederatedWorkerHandlerException("FederatedWorkerHandler: ScalarOperator dont support "
- + dataObject.getDataType().name()));
- }
-
- MatrixObject matrixObject = (MatrixObject) dataObject;
- MatrixBlock inBlock = matrixObject.acquireRead();
- MatrixBlock retBlock = inBlock.scalarOperations(operator, new MatrixBlock());
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, retBlock);
- }
-
- @SuppressWarnings("unused")
- private FederatedResponse createMatrixObject(MatrixBlock result) {
- MatrixObject resTo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName());
- MetaDataFormat metadata = new MetaDataFormat(
- new MatrixCharacteristics(result.getNumRows(), result.getNumColumns()), FileFormat.BINARY);
- resTo.setMetaData(metadata);
- resTo.acquireModify(result);
- resTo.release();
- long result_var = _seq.getNextID();
- _vars.put(result_var, resTo);
- return new FederatedResponse(FederatedResponse.Type.SUCCESS, result_var);
+ return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
}
private static void checkNumParams(int actual, int... expected) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java
index 79c1a6b..77768b3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandlerException.java
@@ -37,4 +37,8 @@ public class FederatedWorkerHandlerException extends RuntimeException {
public FederatedWorkerHandlerException(String msg) {
super(msg);
}
+
+ /**
+ * Creates an exception with the given message and cause, so that the original
+ * failure is preserved when it is wrapped into a federated error response.
+ *
+ * @param msg detail message
+ * @param t cause of this exception
+ */
+ public FederatedWorkerHandlerException(String msg, Throwable t) {
+ super(msg, t);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
new file mode 100644
index 0000000..d2e2300
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.TreeMap;
+import java.util.concurrent.Future;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Mapping of the ranges of a federated object to the {@link FederatedData} of the
+ * workers holding them, together with the ID under which the object is known at the
+ * workers. Provides primitives to broadcast data to, and execute requests at, all workers.
+ */
+public class FederationMap
+{
+ private final long _ID;
+ private final Map<FederatedRange, FederatedData> _fedMap;
+
+ /**
+ * Creates an uninitialized federation map (ID -1).
+ *
+ * @param fedMap mapping of federated ranges to federated data
+ */
+ public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
+ this(-1, fedMap);
+ }
+
+ /**
+ * @param ID ID of the federated object at the workers (negative if uninitialized)
+ * @param fedMap mapping of federated ranges to federated data
+ */
+ public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
+ _ID = ID;
+ _fedMap = fedMap;
+ }
+
+ public long getID() {
+ return _ID;
+ }
+
+ /** @return true if a non-negative ID has been assigned */
+ public boolean isInitialized() {
+ return _ID >= 0;
+ }
+
+ /**
+ * Prepares a single PUT_VAR request that transfers the full data to every worker.
+ *
+ * @param data cacheable data to broadcast
+ * @return request carrying a new federated data ID
+ */
+ public FederatedRequest broadcast(CacheableData<?> data) {
+ long id = FederationUtils.getNextFedDataID();
+ CacheBlock cb = data.acquireReadAndRelease();
+ return new FederatedRequest(RequestType.PUT_VAR, id, cb);
+ }
+
+ /**
+ * Prepares a single PUT_VAR request that transfers the scalar to every worker.
+ *
+ * @param scalar scalar to broadcast
+ * @return request carrying a new federated data ID
+ */
+ public FederatedRequest broadcast(ScalarObject scalar) {
+ long id = FederationUtils.getNextFedDataID();
+ return new FederatedRequest(RequestType.PUT_VAR, id, scalar);
+ }
+
+ /**
+ * Prepares one PUT_VAR request per federated range (in map iteration order), each
+ * carrying the rows of data aligned with that range's rows; if transposed, the
+ * range's rows select columns of data instead. All requests share one new ID.
+ *
+ * @param data cacheable data to slice and broadcast
+ * @param transposed whether to slice columns instead of rows
+ * @return one request per federated range, to be passed to execute(FederatedRequest[], ...)
+ */
+ public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transposed) {
+ long id = FederationUtils.getNextFedDataID();
+ CacheBlock cb = data.acquireReadAndRelease();
+ List<FederatedRequest> ret = new ArrayList<>(_fedMap.size());
+ for(FederatedRange range : _fedMap.keySet()) {
+ //range end dims are treated as exclusive, slice bounds are inclusive
+ int beg = range.getBeginDimsInt()[0];
+ int end = range.getEndDimsInt()[0] - 1;
+ int rl = transposed ? 0 : beg;
+ int ru = transposed ? cb.getNumRows()-1 : end;
+ int cl = transposed ? beg : 0;
+ int cu = transposed ? end : cb.getNumColumns()-1;
+ CacheBlock tmp = cb.slice(rl, ru, cl, cu, new MatrixBlock());
+ ret.add(new FederatedRequest(RequestType.PUT_VAR, id, tmp));
+ }
+ return ret.toArray(new FederatedRequest[0]);
+ }
+
+ /**
+ * Sends the same batch of requests to all federated workers.
+ *
+ * @param fr requests executed in order at every worker
+ * @return one pending response per worker, in map iteration order
+ */
+ @SuppressWarnings("unchecked") //generic array creation of Future<FederatedResponse>[]
+ public Future<FederatedResponse>[] execute(FederatedRequest... fr) {
+ List<Future<FederatedResponse>> ret = new ArrayList<>(_fedMap.size());
+ for(FederatedData fd : _fedMap.values())
+ ret.add(fd.executeFederatedOperation(fr));
+ return ret.toArray(new Future[0]);
+ }
+
+ /**
+ * Sends to the i-th worker (map iteration order) the batch frSlices[i], fr[0], ..., fr[n-1];
+ * only the first step is federated-data-specific (e.g., from broadcastSliced).
+ *
+ * @param frSlices one worker-specific request per federated range
+ * @param fr requests shared by all workers, executed after the slice request
+ * @return one pending response per worker, in map iteration order
+ * @throws DMLRuntimeException if the number of slices does not match the number of ranges
+ */
+ @SuppressWarnings("unchecked") //generic array creation of Future<FederatedResponse>[]
+ public Future<FederatedResponse>[] execute(FederatedRequest[] frSlices, FederatedRequest... fr) {
+ if( frSlices.length != _fedMap.size() )
+ throw new DMLRuntimeException("Number of sliced requests ("+frSlices.length
+ +") does not match number of federated ranges ("+_fedMap.size()+").");
+ List<Future<FederatedResponse>> ret = new ArrayList<>(_fedMap.size());
+ int pos = 0;
+ for(FederatedData fd : _fedMap.values())
+ ret.add(fd.executeFederatedOperation(addAll(frSlices[pos++], fr)));
+ return ret.toArray(new Future[0]);
+ }
+
+ /**
+ * Requests the data of this federated object (GET_VAR with this map's ID) from all workers.
+ *
+ * @return pairs of federated range and pending response, in map iteration order
+ * @throws DMLRuntimeException if this map has no valid ID
+ */
+ public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
+ if( !isInitialized() )
+ throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
+
+ List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>(_fedMap.size());
+ FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
+ for(Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
+ readResponses.add(new ImmutablePair<>(e.getKey(),
+ e.getValue().executeFederatedOperation(request)));
+ return readResponses;
+ }
+
+ /**
+ * Removes the given variables at all workers via an rmvar instruction. The pending
+ * responses are not awaited.
+ *
+ * @param id IDs of the variables to remove
+ */
+ public void cleanup(long... id) {
+ if( id.length == 0 )
+ return; //nothing to remove, avoid sending an rmvar without operands
+ FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
+ VariableCPInstruction.prepareRemoveInstruction(id).toString());
+ for(FederatedData fd : _fedMap.values())
+ fd.executeFederatedOperation(request);
+ }
+
+ //prepends a to b as new array
+ private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
+ FederatedRequest[] ret = new FederatedRequest[b.length + 1];
+ ret[0] = a;
+ System.arraycopy(b, 0, ret, 1, b.length);
+ return ret;
+ }
+
+ /** @return a copy of this map with the same ranges/workers but a new federated data ID */
+ public FederationMap copyWithNewID() {
+ return copyWithNewID(FederationUtils.getNextFedDataID());
+ }
+
+ /**
+ * @param id federated data ID of the copy
+ * @return a copy of this map with the same ranges/workers but the given ID
+ */
+ public FederationMap copyWithNewID(long id) {
+ Map<FederatedRange, FederatedData> map = new TreeMap<>();
+ //TODO handling of file path, but no danger as never written
+ for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
+ map.put(new FederatedRange(e.getKey()), new FederatedData(e.getValue(), id));
+ return new FederationMap(id, map);
+ }
+
+ /**
+ * Appends (in place) the ranges of that map below this one, shifting them by the
+ * given row offset and re-assigning them to this map's ID.
+ *
+ * @param offset row offset added to the ranges of that map
+ * @param that federation map to append
+ * @return this map
+ */
+ public FederationMap rbind(long offset, FederationMap that) {
+ for( Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet() ) {
+ _fedMap.put(
+ new FederatedRange(e.getKey()).shift(offset, 0),
+ new FederatedData(e.getValue(), _ID));
+ }
+ return this;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
new file mode 100644
index 0000000..ab0b3aa
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+
+import java.util.Arrays;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.functionobjects.KahanFunction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+
+/**
+ * Utility functions for federated execution: ID generation, preparation of
+ * instruction requests, and aggregation of partial results from workers.
+ */
+public class FederationUtils {
+ //sequence of IDs for newly created federated data
+ private static final IDSequence _idSeq = new IDSequence();
+
+ /** @return a new ID for federated data */
+ public static long getNextFedDataID() {
+ return _idSeq.getNextID();
+ }
+
+ /**
+ * Prepares an EXEC_INST request for the given instruction, with the output operand
+ * renamed to a new federated ID and every non-null input operand renamed to the
+ * corresponding new ID; a leading SPARK exec type is rewritten to CP.
+ *
+ * @param inst serialized instruction
+ * @param varOldOut output operand of the instruction
+ * @param varOldIn input operands (null entries are left unchanged)
+ * @param varNewIn federated IDs replacing the input operands, aligned with varOldIn
+ * @return request whose ID is the new output ID
+ */
+ public static FederatedRequest callInstruction(String inst, CPOperand varOldOut, CPOperand[] varOldIn, long[] varNewIn) {
+ //TODO better and safe replacement of operand names --> instruction utils
+ long id = getNextFedDataID();
+ String linst = inst;
+ //rewrite only the leading exec type, not any "SPARK" substring of the instruction
+ String sparkPrefix = ExecType.SPARK.name() + Lop.OPERAND_DELIMITOR;
+ if( linst.startsWith(sparkPrefix) )
+ linst = ExecType.CP.name() + linst.substring(ExecType.SPARK.name().length());
+ linst = replaceOperandName(linst, varOldOut.getName(), String.valueOf(id));
+ for(int i=0; i<varOldIn.length; i++)
+ if( varOldIn[i] != null )
+ linst = replaceOperandName(linst, varOldIn[i].getName(), String.valueOf(varNewIn[i]));
+ return new FederatedRequest(RequestType.EXEC_INST, id, linst);
+ }
+
+ //replaces a complete operand name (between operand delimiter and data type prefix),
+ //so that e.g. _mVar1 does not also match the prefix of _mVar12
+ private static String replaceOperandName(String inst, String oldName, String newName) {
+ return inst.replace(
+ Lop.OPERAND_DELIMITOR + oldName + Lop.DATATYPE_PREFIX,
+ Lop.OPERAND_DELIMITOR + newName + Lop.DATATYPE_PREFIX);
+ }
+
+ /**
+ * Sums the matrix results of all responses element-wise. The first result block
+ * is used as accumulator and modified in place.
+ */
+ public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) {
+ try {
+ BinaryOperator bop = InstructionUtils.parseBinaryOperator("+");
+ MatrixBlock ret = (MatrixBlock) (ffr[0].get().getData()[0]);
+ for (int i=1; i<ffr.length; i++) {
+ MatrixBlock tmp = (MatrixBlock) (ffr[i].get().getData()[0]);
+ ret.binaryOperationsInPlace(bop, tmp);
+ }
+ return ret;
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ /** @return the matrix results of all responses, in the given order */
+ public static MatrixBlock[] getResults(Future<FederatedResponse>[] ffr) {
+ try {
+ MatrixBlock[] ret = new MatrixBlock[ffr.length];
+ for(int i=0; i<ffr.length; i++)
+ ret[i] = (MatrixBlock) ffr[i].get().getData()[0];
+ return ret;
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ /** @return the matrix results of all responses appended row-wise, in the given order */
+ public static MatrixBlock rbind(Future<FederatedResponse>[] ffr) {
+ // TODO handle non-contiguous cases
+ try {
+ MatrixBlock[] tmp = getResults(ffr);
+ return tmp[0].append(
+ Arrays.copyOfRange(tmp, 1, tmp.length),
+ new MatrixBlock(), false);
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ /**
+ * Combines scalar partial aggregates by summation (uak+, uasqk+).
+ *
+ * @throws DMLRuntimeException if the aggregation function is not a KahanFunction
+ */
+ public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+ checkKahanAggregate(aop);
+ //compute scalar sum of partial aggregates
+ try {
+ double sum = 0; //uak+, uasqk+
+ for( Future<FederatedResponse> fr : ffr )
+ sum += ((ScalarObject)fr.get().getData()[0]).getDoubleValue();
+ return new DoubleObject(sum);
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ /**
+ * Combines matrix partial aggregates: row aggregates are appended row-wise,
+ * all other aggregates are summed.
+ *
+ * @throws DMLRuntimeException if the aggregation function is not a KahanFunction
+ */
+ public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) {
+ checkKahanAggregate(aop);
+ //assumes full row partitions for row and col aggregates
+ return aop.isRowAggregate() ? rbind(ffr) : aggAdd(ffr);
+ }
+
+ //only sum-like (Kahan) aggregates can be combined by addition; report the actual
+ //function class (the operator class itself is not informative)
+ private static void checkKahanAggregate(AggregateUnaryOperator aop) {
+ if( !(aop.aggOp.increOp.fn instanceof KahanFunction) ) {
+ throw new DMLRuntimeException("Unsupported aggregation operator: "
+ + aop.aggOp.increOp.fn.getClass().getSimpleName());
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
deleted file mode 100644
index feb5a68..0000000
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.controlprogram.federated;
-
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.instructions.cp.KahanObject;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
-import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Future;
-
-/**
- * Library for federated aggregation operations.
- * <p>
- * This libary covers the following opcodes:
- * uak+
- */
-public class LibFederatedAgg
-{
- public static MatrixBlock aggregateUnaryMatrix(MatrixObject federatedMatrix, AggregateUnaryOperator operator) {
- // find out the characteristics after aggregation
- MatrixCharacteristics mc = new MatrixCharacteristics();
- operator.indexFn.computeDimension(federatedMatrix.getDataCharacteristics(), mc);
- // make outBlock right size
- MatrixBlock ret = new MatrixBlock((int) mc.getRows(), (int) mc.getCols(), operator.aggOp.initialValue);
- List<Pair<FederatedRange, Future<FederatedResponse>>> idResponsePairs = new ArrayList<>();
- // distribute aggregation operation to all workers
- for (Map.Entry<FederatedRange, FederatedData> entry : federatedMatrix.getFedMapping().entrySet()) {
- FederatedData fedData = entry.getValue();
- if (!fedData.isInitialized())
- throw new DMLRuntimeException("Not all FederatedData was initialized for federated matrix");
- Future<FederatedResponse> future = fedData.executeFederatedOperation(
- new FederatedRequest(FederatedRequest.FedMethod.AGGREGATE, operator), true);
- idResponsePairs.add(new ImmutablePair<>(entry.getKey(), future));
- }
- try {
- //TODO replace with block operations
- for (Pair<FederatedRange, Future<FederatedResponse>> idResponsePair : idResponsePairs) {
- FederatedRange range = idResponsePair.getLeft();
- FederatedResponse federatedResponse = idResponsePair.getRight().get();
- int[] beginDims = range.getBeginDimsInt();
- MatrixBlock mb = (MatrixBlock) federatedResponse.getData()[0];
- // TODO performance optimizations
- MatrixValue.CellIndex cellIndex = new MatrixValue.CellIndex(0, 0);
- ValueFunction valueFn = operator.aggOp.increOp.fn;
- // Add worker response to resultBlock
- for (int r = 0; r < mb.getNumRows(); r++)
- for (int c = 0; c < mb.getNumColumns(); c++) {
- // Get the output index where the result should be placed by the index function
- // -> map input row/col to output row/col
- cellIndex.set(r + beginDims[0], c + beginDims[1]);
- operator.indexFn.execute(cellIndex, cellIndex);
- int resultRow = cellIndex.row;
- int resultCol = cellIndex.column;
- double newValue;
- if (valueFn instanceof KahanFunction) {
- // TODO iterate along correct axis to use correction correctly
- // temporary solution to execute correct overloaded method
- KahanObject kobj = new KahanObject(ret.quickGetValue(resultRow, resultCol), 0);
- newValue = ((KahanObject) valueFn.execute(kobj, mb.quickGetValue(r, c)))._sum;
- }
- else {
- // TODO special handling for `ValueFunction`s which do not implement `.execute(double, double)`
- // "Add" two partial calculations together with ValueFunction
- newValue = valueFn.execute(ret.quickGetValue(resultRow, resultCol), mb.quickGetValue(r, c));
- }
- ret.quickSetValue(resultRow, resultCol, newValue);
- }
- }
- }
- catch (Exception e) {
- throw new DMLRuntimeException("Federated binary aggregation failed", e);
- }
- return ret;
- }
-}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAppend.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAppend.java
deleted file mode 100644
index 58ad5da..0000000
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAppend.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.runtime.controlprogram.federated;
-
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.meta.DataCharacteristics;
-
-import java.util.Map;
-import java.util.TreeMap;
-
-public class LibFederatedAppend {
- public static MatrixObject federateAppend(MatrixObject matObject1, MatrixObject matObject2,
- MatrixObject matObjectRet, boolean cbind)
- {
- Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
- DataCharacteristics dc = matObjectRet.getDataCharacteristics();
- if (cbind) {
- // check for same amount of rows for matObject1 and matObject2 should have been checked before call
- dc.setRows(matObject1.getNumRows());
- // added because cbind
- long columnsLeftMat = matObject1.getNumColumns();
- dc.setCols(columnsLeftMat + matObject2.getNumColumns());
-
- Map<FederatedRange, FederatedData> fedMappingLeft = matObject1.getFedMapping();
- for (Map.Entry<FederatedRange, FederatedData> entry : fedMappingLeft.entrySet()) {
- // note that FederatedData should not change its varId once set
- fedMapping.put(new FederatedRange(entry.getKey()), entry.getValue());
- }
- Map<FederatedRange, FederatedData> fedMappingRight = matObject2.getFedMapping();
- for (Map.Entry<FederatedRange, FederatedData> entry : fedMappingRight.entrySet()) {
- // add offset due to cbind
- FederatedRange range = new FederatedRange(entry.getKey());
- range.setBeginDim(1, columnsLeftMat + range.getBeginDims()[1]);
- range.setEndDim(1, columnsLeftMat + range.getEndDims()[1]);
- fedMapping.put(range, entry.getValue());
- }
- }
- else {
- // check for same amount of cols for matObject1 and matObject2 should have been checked before call
- dc.setCols(matObject1.getNumColumns());
- // added because rbind
- long rowsUpperMat = matObject1.getNumRows();
- dc.setRows(rowsUpperMat + matObject2.getNumRows());
-
- Map<FederatedRange, FederatedData> fedMappingUpper = matObject1.getFedMapping();
- for (Map.Entry<FederatedRange, FederatedData> entry : fedMappingUpper.entrySet()) {
- // note that FederatedData should not change its varId once set
- fedMapping.put(new FederatedRange(entry.getKey()), entry.getValue());
- }
- Map<FederatedRange, FederatedData> fedMappingLower = matObject2.getFedMapping();
- for (Map.Entry<FederatedRange, FederatedData> entry : fedMappingLower.entrySet()) {
- // add offset due to rbind
- FederatedRange range = new FederatedRange(entry.getKey());
- range.setBeginDim(0, rowsUpperMat + range.getBeginDims()[0]);
- range.setEndDim(0, rowsUpperMat + range.getEndDims()[0]);
- fedMapping.put(range, entry.getValue());
- }
- }
- matObjectRet.setFedMapping(fedMapping);
- dc.setNonZeros(matObject1.getNnz() + matObject2.getNnz());
- return matObjectRet;
- }
-}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index d640313..db70d6b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -115,7 +115,7 @@ public final class MatrixIndexingCPInstruction extends IndexingCPInstruction {
resultBlock.examSparsity();
//unpin output
- ec.setMatrixOutput(output.getName(), resultBlock, updateType, getExtendedOpcode());
+ ec.setMatrixOutput(output.getName(), resultBlock, updateType);
}
else
throw new DMLRuntimeException("Invalid opcode (" + opcode +") encountered in MatrixIndexingCPInstruction.");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 2082047..f7f3698 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.cp;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang.StringUtils;
@@ -1102,6 +1103,12 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace
}
}
+ public static Instruction prepareRemoveInstruction(long... varName) {
+ String[] tmp = new String[varName.length];
+ Arrays.setAll(tmp, i -> String.valueOf(varName[i]));
+ return prepareRemoveInstruction(tmp);
+ }
+
public static Instruction prepareRemoveInstruction(String... varNames) {
StringBuilder sb = new StringBuilder();
sb.append("CP");
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 2c064fc..3fe1004 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -19,31 +19,18 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.MutablePair;
-import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.functionobjects.Multiply;
-import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
@@ -56,329 +43,53 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
public static AggregateBinaryFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
-
- if(!opcode.equalsIgnoreCase("ba+*")) {
+ if(!opcode.equalsIgnoreCase("ba+*"))
throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
- }
InstructionUtils.checkNumFields(parts, 4);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
int k = Integer.parseInt(parts[4]);
-
return new AggregateBinaryFEDInstruction(
InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str);
}
@Override
public void processInstruction(ExecutionContext ec) {
- //get inputs
- MatrixObject mo1 = ec.getMatrixObject(input1.getName());
- MatrixObject mo2 = ec.getMatrixObject(input2.getName());
- MatrixObject out = ec.getMatrixObject(output.getName());
-
- // compute matrix-vector multiplication
- AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
- if (mo1.isFederated() && mo2.getNumColumns() == 1) {// MV
- MatrixBlock vector = mo2.acquireRead();
- federatedAggregateBinaryMV(mo1, vector, out, ab_op, true);
- mo2.release();
- }
- else if (mo2.isFederated() && mo1.getNumRows() == 1) {// VM
- MatrixBlock vector = mo1.acquireRead();
- federatedAggregateBinaryMV(mo2, vector, out, ab_op, false);
- mo1.release();
- }
- else // MM
- federatedAggregateBinary(mo1, mo2, out);
- }
-
- /**
- * Performs a federated binary aggregation (currently only MV and VM is supported).
- *
- * @param mo1 the first matrix object
- * @param mo2 the other matrix object
- * @param out output matrix object
- */
- private static void federatedAggregateBinary(MatrixObject mo1, MatrixObject mo2, MatrixObject out) {
- boolean distributeCols = false;
- // if distributeCols = true we distribute cols of mo2 and do a MV multiplications, otherwise we
- // distribute rows of mo1 and do VM multiplications
- if (mo1.isFederated() && mo2.isFederated()) {
- // both are federated -> distribute smaller matrix
- // TODO do more in depth checks like: how many federated workers, how big is the actual data we send and so on
- // maybe once we track number of non zeros we could use that to get a better estimation of how much data
- // will be requested?
- distributeCols = mo2.getNumColumns() * mo2.getNumRows() < mo1.getNumColumns() * mo1.getNumRows();
- }
- else if (mo2.isFederated() && !mo1.isFederated()) {
- // Distribute mo1 which is not federated
- distributeCols = true;
- }
- // TODO performance if both matrices are federated
- Map<FederatedRange, FederatedData> mapping = distributeCols ? mo1.getFedMapping() : mo2.getFedMapping();
- MatrixBlock matrixBlock = distributeCols ? mo2.acquireRead() : mo1.acquireRead();
- ExecutorService pool = CommonThreadPool.get(mapping.size());
- ArrayList<Pair<FederatedRange, MatrixBlock>> results = new ArrayList<>();
- ArrayList<FederatedMMTask> tasks = new ArrayList<>();
- for (Map.Entry<FederatedRange, FederatedData> fedMap : mapping.entrySet()) {
- // this resultPair will contain both position of partial result and the partial result itself of the operations
- MutablePair<FederatedRange, MatrixBlock> resultPair = new MutablePair<>();
- // they all get references to the real block, the task slices out the needed part and does the
- // multiplication, therefore they can share the object since we use it immutably
- tasks.add(new FederatedMMTask(fedMap.getKey(), fedMap.getValue(), resultPair, matrixBlock, distributeCols));
- results.add(resultPair);
- }
- CommonThreadPool.invokeAndShutdown(pool, tasks);
- (distributeCols?mo2:mo1).release();
-
- // combine results
- if (mo1.getNumRows() > Integer.MAX_VALUE || mo2.getNumColumns() > Integer.MAX_VALUE) {
- throw new DMLRuntimeException("Federated matrix is too large for federated distribution");
- }
- out.acquireModify(combinePartialMMResults(results, (int) mo1.getNumRows(), (int) mo2.getNumColumns()));
- out.release();
- }
-
- private static MatrixBlock combinePartialMMResults(ArrayList<Pair<FederatedRange, MatrixBlock>> results,
- int rows, int cols) {
- // TODO support large blocks with > int size
- MatrixBlock resultBlock = new MatrixBlock(rows, cols, false);
- for (Pair<FederatedRange, MatrixBlock> partialResult : results) {
- FederatedRange range = partialResult.getLeft();
- MatrixBlock partialBlock = partialResult.getRight();
- int[] dimsLower = range.getBeginDimsInt();
- int[] dimsUpper = range.getEndDimsInt();
- resultBlock.copy(dimsLower[0], dimsUpper[0] - 1, dimsLower[1], dimsUpper[1] - 1, partialBlock, false);
- }
- resultBlock.recomputeNonZeros();
- return resultBlock;
- }
-
- /**
- * Performs a federated binary aggregation on a Matrix and Vector.
- *
- * @param fedMo the federated matrix object
- * @param vector the vector
- * @param output the output matrix object
- * @param op the operation
- * @param matrixVectorOp true if matrix vector operation, false if vector matrix op
- */
- public static void federatedAggregateBinaryMV(MatrixObject fedMo, MatrixBlock vector, MatrixObject output,
- AggregateBinaryOperator op, boolean matrixVectorOp) {
- if (!(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus))
- throw new DMLRuntimeException("Only matrix-vector is supported for federated binary aggregation");
- // fixed implementation only for mv, vm multiply and plus
- // TODO move this to a Lib class?
-
- // create output matrix
- MatrixBlock resultBlock;
- // if we change the order of parameters, so VM instead of MV, the output has different dimensions
- if (!matrixVectorOp) {
- output.getDataCharacteristics().setRows(1).setCols(fedMo.getNumColumns());
- resultBlock = new MatrixBlock(1, (int) fedMo.getNumColumns(), false);
- }
- else {
- output.getDataCharacteristics().setRows(fedMo.getNumRows()).setCols(1);
- resultBlock = new MatrixBlock((int) fedMo.getNumRows(), 1, false);
- }
- List<Pair<FederatedRange, Future<FederatedResponse>>> idResponsePairs = new ArrayList<>();
- // TODO parallel for loop (like on lines 125-136)
- for (Map.Entry<FederatedRange, FederatedData> entry : fedMo.getFedMapping().entrySet()) {
- FederatedRange range = entry.getKey();
- FederatedData fedData = entry.getValue();
- Future<FederatedResponse> future = executeMVMultiply(range, fedData, vector, matrixVectorOp);
- idResponsePairs.add(new ImmutablePair<>(range, future));
- }
- try {
- for (Pair<FederatedRange, Future<FederatedResponse>> idResponsePair : idResponsePairs) {
- FederatedRange range = idResponsePair.getLeft();
- FederatedResponse federatedResponse = idResponsePair.getRight().get();
- combinePartialMVResults(range, federatedResponse, resultBlock, matrixVectorOp);
- }
- }
- catch (Exception e) {
- throw new DMLRuntimeException("Federated binary aggregation failed", e);
- }
- long nnz = resultBlock.recomputeNonZeros();
- output.acquireModify(resultBlock);
- output.getDataCharacteristics().setNonZeros(nnz);
- output.release();
- }
-
- private static void combinePartialMVResults(FederatedRange range,
- FederatedResponse federatedResponse, MatrixBlock resultBlock, boolean matrixVectorOp)
- {
- try {
- int[] beginDims = range.getBeginDimsInt();
- MatrixBlock mb = (MatrixBlock) federatedResponse.getData()[0];
- // TODO performance optimizations
- // TODO Improve Vector Matrix multiplication accuracy: An idea would be to make use of kahan plus here,
- // this should improve accuracy a bit, although we still lose out on the small error lost on the worker
- // without having to return twice the amount of data (value + sum error)
- // Add worker response to resultBlock
- for (int r = 0; r < mb.getNumRows(); r++)
- for (int c = 0; c < mb.getNumColumns(); c++) {
- int resultRow = r + (!matrixVectorOp ? 0 : beginDims[0]);
- int resultColumn = c + (!matrixVectorOp ? beginDims[1] : 0);
- resultBlock.quickSetValue(resultRow, resultColumn,
- resultBlock.quickGetValue(resultRow, resultColumn) + mb.quickGetValue(r, c));
- }
- } catch (Exception e){
- throw new DMLRuntimeException("Combine partial results from federated matrix failed.", e);
- }
- }
-
- private static Future<FederatedResponse> executeMVMultiply(FederatedRange range,
- FederatedData fedData, MatrixBlock vector, boolean matrixVectorOp)
- {
- if (!fedData.isInitialized()) {
- throw new DMLRuntimeException("Not all FederatedData was initialized for federated matrix");
- }
- int[] beginDimsInt = range.getBeginDimsInt();
- int[] endDimsInt = range.getEndDimsInt();
- // params for federated request
- List<Object> params = new ArrayList<>();
- // we broadcast the needed part of the small vector
- MatrixBlock vectorSlice;
- if (!matrixVectorOp) {
- // if we the size already is ok, we do not have to copy a slice
- int length = endDimsInt[0] - beginDimsInt[0];
- if (vector.getNumColumns() == length) {
- vectorSlice = vector;
- }
- else {
- vectorSlice = new MatrixBlock(1, length, false);
- vector.slice(0, 0, beginDimsInt[0], endDimsInt[0] - 1, vectorSlice);
- }
- }
- else {
- int length = endDimsInt[1] - beginDimsInt[1];
- if (vector.getNumRows() == length) {
- vectorSlice = vector;
- }
- else {
- vectorSlice = new MatrixBlock(length, 1, false);
- vector.slice(beginDimsInt[1], endDimsInt[1] - 1, 0, 0, vectorSlice);
- }
- }
- params.add(vectorSlice);
- params.add(matrixVectorOp); // if is matrix vector multiplication true, otherwise false
- return fedData.executeFederatedOperation(
- new FederatedRequest(FederatedRequest.FedMethod.MATVECMULT, params), true);
- }
-
- private static class FederatedMMTask implements Callable<Void> {
- private FederatedRange _range;
- private FederatedData _data;
- private MutablePair<FederatedRange, MatrixBlock> _result;
- private MatrixBlock _otherMatrix;
- private boolean _distributeCols;
-
- public FederatedMMTask(FederatedRange range, FederatedData fedData,
- MutablePair<FederatedRange, MatrixBlock> result, MatrixBlock otherMatrix, boolean distributeCols)
- {
- _range = range;
- _data = fedData;
- _result = result;
- _otherMatrix = otherMatrix;
- _distributeCols = distributeCols;
- }
-
- @Override
- public Void call() throws Exception {
- if (_distributeCols)
- executeColWiseMVMultiplication();
- else
- executeRowWiseVMMultiplications();
- return null;
- }
-
- /**
- * Distribute the non or smaller federated block as row vectors to the federated worker and do row number of
- * times a vector-matrix multiplication. Non or smaller federated block is left operand.
- *
- * @throws InterruptedException if .get() on federated response future fails -> interrupted
- * @throws ExecutionException if .get() on federated response future fails -> execution failed
- */
- private void executeRowWiseVMMultiplications() throws InterruptedException, ExecutionException {
- MatrixBlock result;
- // TODO support large matrices with long indexes
- int[] beginDims = _range.getBeginDimsInt();
- int[] endDims = _range.getEndDimsInt();
- // we take all rows but only the columns between the rows of the federated block of the other block (left
- // hand side of the calculation).
- int rowsBeginOtherBlock = 0;
- int colsBeginOtherBlock = beginDims[0];
- int rowsEndOtherBlock = _otherMatrix.getNumRows();
- int colsEndOtherBlock = endDims[0];
- // Size of partial result block for vm is rows of otherBlock * cols of federatedData
- result = new MatrixBlock(rowsEndOtherBlock - rowsBeginOtherBlock, endDims[1] - beginDims[1], false);
- // Set range of output in result block, rows are the number of rows of the other block, while columns
- // are the number of columns of our federated data
- _result.setLeft(new FederatedRange(new long[] {rowsBeginOtherBlock, beginDims[1]},
- new long[] {rowsEndOtherBlock, endDims[1]}));
- // vector which we will distribute otherBlock.rows number of times
- MatrixBlock vec = new MatrixBlock(1, colsEndOtherBlock - colsBeginOtherBlock, false);
- for (int r = rowsBeginOtherBlock; r < rowsEndOtherBlock; r++) {
- // slice row vector out of other matrix which we will send to federated worker
- _otherMatrix.slice(r, r, colsBeginOtherBlock, colsEndOtherBlock - 1, vec);
- // TODO experiment if sending multiple requests at the same time to the same worker increases
- // performance (remove get and do multithreaded?)
- FederatedResponse response = executeMVMultiply(_range, _data, vec, _distributeCols).get();
- try{
- result.copy(r, r, 0, endDims[1] - beginDims[1] - 1, (MatrixBlock) response.getData()[0], true);
- } catch (Exception e) {
- throw new DMLRuntimeException(
- "Federated Matrix-Matrix Multiplication failed: ", e);
- }
- }
- _result.setRight(result);
- }
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+ MatrixObject mo2 = ec.getMatrixObject(input2);
- /**
- * Distribute the non or smaller federated block as col vectors to the federated worker and do column number of
- * times a matrix-vector multiplication. Non or smaller federated block is right operand.
- *
- * @throws InterruptedException if .get() on federated response future fails -> interrupted
- * @throws ExecutionException if .get() on federated response future fails -> execution failed
- */
- private void executeColWiseMVMultiplication()
- throws InterruptedException, ExecutionException {
- MatrixBlock result;
- // TODO support large matrices with long indexes
- int[] beginDims = _range.getBeginDimsInt();
- int[] endDims = _range.getEndDimsInt();
- // we take all columns but only the rows between the columns of the federated block of the other block (right
- // hand side of the calculation).
- int rowsBeginOtherBlock = beginDims[1];
- int colsBeginOtherBlock = 0;
- int rowsEndOtherBlock = endDims[1];
- int colsEndOtherBlock = _otherMatrix.getNumColumns();
- // Size of partial result block for mv is rows of federated block * cols of other block
- result = new MatrixBlock(endDims[0] - beginDims[0], colsEndOtherBlock - colsBeginOtherBlock, false);
- // Set range of output in result block, rows are the number of rows of the federated data, while columns
- // are the number of columns of the other block
- _result.setLeft(new FederatedRange(new long[] {beginDims[0], colsBeginOtherBlock},
- new long[] {endDims[0], colsEndOtherBlock}));
- // vector which we will distribute otherBlock.cols number of times
- MatrixBlock vec = new MatrixBlock(rowsEndOtherBlock - rowsBeginOtherBlock, 1, false);
- for (int c = colsBeginOtherBlock; c < colsEndOtherBlock; c++) {
- // slice column vector out of other matrix which we will send to federated worker
- _otherMatrix.slice(rowsBeginOtherBlock, rowsEndOtherBlock - 1, c, c, vec);
- // TODO experiment if sending multiple requests at the same time to the same worker increases
- // performance
- FederatedResponse response = executeMVMultiply(_range, _data, vec, _distributeCols).get();
- try {
- result.copy(0, endDims[0] - beginDims[0] - 1, c, c, (MatrixBlock) response.getData()[0], true);
- } catch (Exception e){
- throw new DMLRuntimeException(
- "Federated Matrix-Matrix Multiplication failed: ", e);
- }
-
- }
- _result.setRight(result);
+ //#1 federated matrix-vector multiplication
+ if(mo1.isFederated()) { // MV + MM
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.rbind(tmp);
+ mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ //TODO should remain federated matrix (no need for agg)
+ }
+ //#2 vector - federated matrix multiplication
+ else if (mo2.isFederated()) {// VM + MM
+ //construct commands: broadcast rhs, fed mv, retrieve results
+ FederatedRequest[] fr1 = mo2.getFedMapping().broadcastSliced(mo1, true);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
+ FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(fr1, fr2, fr3);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ mo2.getFedMapping().cleanup(fr1[0].getID(), fr2.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else { //other combinations
+ throw new DMLRuntimeException("Federated AggregateBinary not supported with the "
+ + "following federated objects: "+mo1.isFederated()+" "+mo2.isFederated());
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 3b9f57c..e5dd81e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -19,23 +19,22 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.sysds.common.Types;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.runtime.DMLRuntimeException;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.LibFederatedAgg;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
- private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop, AggregateOperator aop, CPOperand in,
+ private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop, CPOperand in,
CPOperand out, String opcode, String istr) {
super(FEDType.AggregateUnary, auop, in, out, opcode, istr);
}
@@ -45,37 +44,28 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
- String aopcode = InstructionUtils.deriveAggregateOperatorOpcode(opcode);
- Types.CorrectionLocationType corrLoc = InstructionUtils
- .deriveAggregateOperatorCorrectionLocation(opcode);
AggregateUnaryOperator aggun = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
- AggregateOperator aop = InstructionUtils.parseAggregateOperator(aopcode, corrLoc.toString());
- return new AggregateUnaryFEDInstruction(aggun, aop, in1, out, opcode, str);
+ if(InstructionUtils.getExecType(str) == ExecType.SPARK)
+ str = InstructionUtils.replaceOperand(str, 4, "-1");
+ return new AggregateUnaryFEDInstruction(aggun, in1, out, opcode, str);
}
@Override
public void processInstruction(ExecutionContext ec) {
- String output_name = output.getName();
- String opcode = getOpcode();
+ AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
+ MatrixObject in = ec.getMatrixObject(input1);
+
+ //create federated commands for aggregation
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
- AggregateUnaryOperator au_op = (AggregateUnaryOperator) _optr;
- MatrixObject matrixObject;
- if (input1.getDataType() == DataType.MATRIX &&
- (matrixObject = ec.getMatrixObject(input1.getName())).isFederated()) {
- MatrixBlock outMatrix = LibFederatedAgg.aggregateUnaryMatrix(matrixObject, au_op);
-
- if (output.getDataType() == DataType.SCALAR) {
- DoubleObject ret = new DoubleObject(outMatrix.getValue(0, 0));
- ec.setScalarOutput(output_name, ret);
- }
- else {
- ec.setMatrixOutput(output_name, outMatrix);
- ec.getMatrixObject(output_name).getDataCharacteristics()
- .setBlocksize(ConfigurationManager.getBlocksize());
- }
- }
- else {
- throw new DMLRuntimeException(opcode + " only supported on federated matrix.");
- }
+ //execute federated commands and cleanups
+ Future<FederatedResponse>[] tmp = in.getFedMapping().execute(fr1, fr2);
+ in.getFedMapping().cleanup(fr1.getID());
+ if( output.isScalar() )
+ ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
+ else
+ ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index d770063..8fed7f7 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -22,27 +22,22 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.LibFederatedAppend;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.OffsetColumnIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
public class AppendFEDInstruction extends BinaryFEDInstruction {
- public enum FEDAppendType {
- CBIND, RBIND;
- public boolean isCBind() {
- return this == CBIND;
- }
- }
-
- protected final FEDAppendType _type;
+ protected final boolean _cbind; //otherwise rbind
- protected AppendFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, FEDAppendType type,
- String opcode, String istr) {
+ protected AppendFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ boolean cbind, String opcode, String istr) {
super(FEDType.Append, op, in1, in2, out, opcode, istr);
- _type = type;
+ _cbind = cbind;
}
public static AppendFEDInstruction parseInstruction(String str) {
@@ -55,36 +50,54 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
CPOperand out = new CPOperand(parts[parts.length - 2]);
boolean cbind = Boolean.parseBoolean(parts[parts.length - 1]);
- FEDAppendType type = cbind ? FEDAppendType.CBIND : FEDAppendType.RBIND;
-
- if (!opcode.equalsIgnoreCase("append") && !opcode.equalsIgnoreCase("remove")
- && !opcode.equalsIgnoreCase("galignedappend"))
- throw new DMLRuntimeException("Unknown opcode while parsing a AppendCPInstruction: " + str);
-
Operator op = new ReorgOperator(OffsetColumnIndex.getOffsetColumnIndexFnObject(-1));
- return new AppendFEDInstruction(op, in1, in2, out, type, opcode, str);
+ return new AppendFEDInstruction(op, in1, in2, out, cbind, opcode, str);
}
@Override
public void processInstruction(ExecutionContext ec) {
//get inputs
- MatrixObject matObject1 = ec.getMatrixObject(input1.getName());
- MatrixObject matObject2 = ec.getMatrixObject(input2.getName());
+ MatrixObject mo1 = ec.getMatrixObject(input1.getName());
+ MatrixObject mo2 = ec.getMatrixObject(input2.getName());
+ DataCharacteristics dc1 = mo1.getDataCharacteristics();
+ DataCharacteristics dc2 = mo2.getDataCharacteristics();
+
//check input dimensions
- if (_type == FEDAppendType.CBIND && matObject1.getNumRows() != matObject2.getNumRows()) {
+ if (_cbind && mo1.getNumRows() != mo2.getNumRows()) {
throw new DMLRuntimeException(
"Append-cbind is not possible for federated input matrices " + input1.getName() + " and "
- + input2.getName() + " with different number of rows: " + matObject1.getNumRows() + " vs "
- + matObject2.getNumRows());
+ + input2.getName() + " with different number of rows: " + mo1.getNumRows() + " vs "
+ + mo2.getNumRows());
}
- else if (_type == FEDAppendType.RBIND && matObject1.getNumColumns() != matObject2.getNumColumns()) {
+ else if (!_cbind && mo1.getNumColumns() != mo2.getNumColumns()) {
throw new DMLRuntimeException(
"Append-rbind is not possible for federated input matrices " + input1.getName() + " and "
- + input2.getName() + " with different number of columns: " + matObject1.getNumColumns()
- + " vs " + matObject2.getNumColumns());
+ + input2.getName() + " with different number of columns: " + mo1.getNumColumns()
+ + " vs " + mo2.getNumColumns());
+ }
+
+ if( mo1.isFederated() && _cbind ) {
+ FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+ mo1.getFedMapping().execute(fr1, fr2);
+ mo1.getFedMapping().cleanup(fr1.getID()); //cleanup broadcast of mo2
+ //derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(dc1.getRows(), dc1.getCols()+dc2.getCols(),
+ dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID()));
+ }
+ else if( mo1.isFederated() && mo2.isFederated() && !_cbind ) {
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(dc1.getRows()+dc2.getRows(), dc1.getCols(),
+ dc1.getBlocksize(), dc1.getNonZeros()+dc2.getNonZeros());
+ long id = FederationUtils.getNextFedDataID();
+ out.setFedMapping(mo1.getFedMapping().copyWithNewID(id).rbind(dc1.getRows(), mo2.getFedMapping()));
+ }
+ else { //other combinations
+ throw new DMLRuntimeException("Federated Append not supported with the "
+ + "following federated objects: "+mo1.isFederated()+" "+mo2.isFederated());
}
- // append MatrixObjects
- LibFederatedAppend.federateAppend(matObject1, matObject2,
- ec.getMatrixObject(output.getName()), _type.isCBind());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 3c9db11..0e05ca8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -19,25 +19,12 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.commons.lang3.tuple.ImmutablePair;
-import org.apache.commons.lang3.tuple.Pair;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.Future;
public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
{
@@ -48,47 +35,25 @@ public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
@Override
public void processInstruction(ExecutionContext ec) {
- MatrixObject matrix = ec.getMatrixObject(input1.isMatrix() ? input1 : input2);
- ScalarObject scalar = ec.getScalarInput(input2.isScalar() ? input2 : input1);
-
- ScalarOperator sc_op = (ScalarOperator) _optr;
- sc_op = sc_op.setConstant(scalar.getDoubleValue());
-
- if (!matrix.isFederated())
- throw new DMLRuntimeException("Trying to execute federated operation on non federated matrix");
-
- MatrixBlock ret = new MatrixBlock((int)matrix.getNumRows(), (int) matrix.getNumColumns(), false);
- try {
- //Keep track on federated execution ond matrix shards
- List<Pair<FederatedRange, Future<FederatedResponse>>> idResponsePairs = new ArrayList<>();
-
- //execute federated operation
- for (Map.Entry<FederatedRange, FederatedData> entry : matrix.getFedMapping().entrySet()) {
- FederatedData shard = entry.getValue();
- if (!shard.isInitialized())
- throw new DMLRuntimeException("Not all FederatedData was initialized for federated matrix");
- Future<FederatedResponse> future = shard.executeFederatedOperation(
- new FederatedRequest(FederatedRequest.FedMethod.SCALAR, sc_op), true);
- idResponsePairs.add(new ImmutablePair<>(entry.getKey(), future));
- }
-
- for (Pair<FederatedRange, Future<FederatedResponse>> idResponsePair : idResponsePairs) {
- FederatedRange range = idResponsePair.getLeft();
- //wait for fed workers finishing their work
- FederatedResponse federatedResponse = idResponsePair.getRight().get();
-
- MatrixBlock shard = (MatrixBlock) federatedResponse.getData()[0];
- ret.copy(range.getBeginDimsInt()[0], range.getEndDimsInt()[0]-1,
- range.getBeginDimsInt()[1], range.getEndDimsInt()[1]-1, shard, false);
- }
- }
- catch (Exception e) {
- throw new DMLRuntimeException("Federated binary operation failed", e);
- }
-
- if(ret.getNumRows() != matrix.getNumRows() || ret.getNumColumns() != matrix.getNumColumns())
- throw new DMLRuntimeException("Federated binary operation returns invalid matrix dimension");
+ CPOperand matrix = input1.isMatrix() ? input1 : input2;
+ CPOperand scalar = input2.isScalar() ? input2 : input1;
+ MatrixObject mo = ec.getMatrixObject(matrix);
+
+ //execute federated matrix-scalar operation and cleanups
+ FederatedRequest fr1 = !scalar.isLiteral() ?
+ mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{matrix, (fr1 != null)?scalar:null},
+ new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1});
+
+ mo.getFedMapping().execute((fr1!=null) ?
+ new FederatedRequest[]{fr1, fr2}: new FederatedRequest[]{fr2});
+ if( fr1 != null )
+ mo.getFedMapping().cleanup(fr1.getID());
- ec.setMatrixOutput(output.getName(), ret);
+ //derive new fed mapping for output
+ MatrixObject out = ec.getMatrixObject(output);
+ out.getDataCharacteristics().set(mo.getDataCharacteristics());
+ out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index e4b7dac..8d050b3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -30,6 +31,8 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -103,7 +106,6 @@ public class InitFEDInstruction extends FEDInstruction {
for (int i = 0; i < addresses.getLength(); i++) {
Data addressData = addresses.getData().get(i);
if (addressData instanceof StringObject) {
-
// We split address into url/ip, the port and file path of file to read
String[] parsedValues = parseURL(((StringObject) addressData).getStringValue());
String host = parsedValues[0];
@@ -136,7 +138,6 @@ public class InitFEDInstruction extends FEDInstruction {
catch (UnknownHostException e) {
throw new DMLRuntimeException("federated host was unknown: " + host);
}
-
}
else {
throw new DMLRuntimeException("federated instruction only takes strings as addresses");
@@ -206,6 +207,7 @@ public class InitFEDInstruction extends FEDInstruction {
fedMapping.put(t.getLeft(), t.getRight());
}
List<Pair<FederatedData, Future<FederatedResponse>>> idResponses = new ArrayList<>();
+ long id = FederationUtils.getNextFedDataID();
for (Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
@@ -217,7 +219,7 @@ public class InitFEDInstruction extends FEDInstruction {
dims[i] = endDims[i] - beginDims[i];
}
// TODO check if all matrices have the same DataType (currently only double is supported)
- idResponses.add(new ImmutablePair<>(value, value.initFederatedData()));
+ idResponses.add(new ImmutablePair<>(value, value.initFederatedData(id)));
}
}
try {
@@ -230,7 +232,8 @@ public class InitFEDInstruction extends FEDInstruction {
throw new DMLRuntimeException("Federation initialization failed", e);
}
output.getDataCharacteristics().setNonZeros(output.getNumColumns() * output.getNumRows());
- output.setFedMapping(fedMapping);
+ output.getDataCharacteristics().setBlocksize(ConfigurationManager.getBlocksize());
+ output.setFedMapping(new FederationMap(id, fedMapping));
}
public void federateFrame(FrameObject output, List<Pair<FederatedRange, FederatedData>> workers) {
@@ -242,6 +245,7 @@ public class InitFEDInstruction extends FEDInstruction {
// on the distributed workers. We need the FederatedData, the starting column of the sub frame (for the schema)
// and the future for the response
List<Pair<FederatedData, Pair<Integer, Future<FederatedResponse>>>> idResponses = new ArrayList<>();
+ long id = FederationUtils.getNextFedDataID();
for (Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
FederatedRange range = entry.getKey();
FederatedData value = entry.getValue();
@@ -252,7 +256,7 @@ public class InitFEDInstruction extends FEDInstruction {
for (int i = 0; i < dims.length; i++) {
dims[i] = endDims[i] - beginDims[i];
}
- idResponses.add(new ImmutablePair<>(value, new ImmutablePair<>((int) beginDims[1], value.initFederatedData())));
+ idResponses.add(new ImmutablePair<>(value, new ImmutablePair<>((int) beginDims[1], value.initFederatedData(id))));
}
}
// columns are definitely in int range, because we throw an DMLRuntime Exception in `processInstruction` else
@@ -271,7 +275,7 @@ public class InitFEDInstruction extends FEDInstruction {
}
output.getDataCharacteristics().setNonZeros(output.getNumColumns() * output.getNumRows());
output.setSchema(schema);
- output.setFedMapping(fedMapping);
+ output.setFedMapping(new FederationMap(id, fedMapping));
}
private static void handleFedFrameResponse(Types.ValueType[] schema, FederatedData federatedData,
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
index 6d0ca53..a1faae0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/AggregateUnaryOperator.java
@@ -26,6 +26,8 @@ import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Or;
import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
public class AggregateUnaryOperator extends Operator
@@ -58,4 +60,12 @@ public class AggregateUnaryOperator extends Operator
public int getNumThreads(){
return k;
}
+
+ public boolean isRowAggregate() {
+ return indexFn instanceof ReduceCol;
+ }
+
+ public boolean isColAggregate() {
+ return indexFn instanceof ReduceRow;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ad0b6d7..77149d1 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -24,19 +24,12 @@ import java.util.Arrays;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
-import java.util.Map;
import java.util.Set;
-import java.util.concurrent.Future;
import org.apache.commons.lang.ArrayUtils;
-import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -801,23 +794,4 @@ public class UtilFunctions {
break;
}
}
-
- public static List<org.apache.commons.lang3.tuple.Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData(
- Map<FederatedRange, FederatedData> fedMapping) {
- List<org.apache.commons.lang3.tuple.Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
- for(Map.Entry<FederatedRange, FederatedData> entry : fedMapping.entrySet()) {
- FederatedRange range = entry.getKey();
- FederatedData fd = entry.getValue();
-
- if(fd.isInitialized()) {
- FederatedRequest request = new FederatedRequest(FederatedRequest.FedMethod.TRANSFER);
- Future<FederatedResponse> readResponse = fd.executeFederatedOperation(request, true);
- readResponses.add(new ImmutablePair<>(range, readResponse));
- }
- else {
- throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
- }
- }
- return readResponses;
- }
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 8c1973b..bf62c34 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -1200,7 +1200,7 @@ public abstract class AutomatedTestBase {
if(exceptionExpected)
fail("expected exception which has not been raised: " + expectedException);
}
- catch(Exception e) {
+ catch(Exception | Error e) {
if( !outputBuffering )
e.printStackTrace();
if(errMessage != null && !errMessage.equals("")) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
index 340234b..b4ce148 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
@@ -117,12 +117,10 @@ public class FederatedConstructionTest extends AutomatedTestBase {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
- Thread t;
-
String HOME = SCRIPT_DIR + TEST_DIR;
int port = getRandomAvailablePort();
- t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorker(port);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedMultiplyTest.java
index 67a74be..7968dd7 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedMultiplyTest.java
@@ -77,8 +77,6 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- Thread t1, t2;
-
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -98,8 +96,8 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
int port1 = getRandomAvailablePort();
int port2 = getRandomAvailablePort();
- t1 = startLocalFedWorker(port1);
- t2 = startLocalFedWorker(port2);
+ Thread t1 = startLocalFedWorker(port1);
+ Thread t2 = startLocalFedWorker(port2);
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
index 93e1feb..81c0b00 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedRCBindTest.java
@@ -48,7 +48,8 @@ public class FederatedRCBindTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{1, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}});
+ //TODO add tests and support of aligned blocksizes (which is however a special case)
+ return Arrays.asList(new Object[][] {{1, 1001}, {10, 100}, {100, 10}, {1001, 1}, {10, 2001}, {2001, 10}});
}
@Override
@@ -71,8 +72,6 @@ public class FederatedRCBindTest extends AutomatedTestBase {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
- Thread t;
-
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -80,7 +79,7 @@ public class FederatedRCBindTest extends AutomatedTestBase {
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows, cols, blocksize, rows * cols));
int port = getRandomAvailablePort();
- t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorker(port);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedSumTest.java
index 3118f01..69a743c 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedSumTest.java
@@ -72,15 +72,13 @@ public class FederatedSumTest extends AutomatedTestBase {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
- Thread t;
-
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
double[][] A = getRandomMatrix(rows / 2, cols, -10, 10, 1, 1);
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows / 2, cols, blocksize, (rows / 2) * cols));
int port = getRandomAvailablePort();
- t = startLocalFedWorker(port);
+ Thread t = startLocalFedWorker(port);
// we need the reference file to not be written to hdfs, so we get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
diff --git a/src/test/scripts/functions/federated/FederatedSumTest.dml b/src/test/scripts/functions/federated/FederatedSumTest.dml
index 8a8efb3..37a19f6 100644
--- a/src/test/scripts/functions/federated/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSumTest.dml
@@ -23,6 +23,7 @@ A = federated(addresses=list($in, $in), ranges=list(list(0, 0), list($rows / 2,
s = sum(A)
r = rowSums(A)
c = colSums(A)
+
write(s, $out_S)
write(r, $out_R)
write(c, $out_C)
diff --git a/src/test/scripts/functions/federated/matrix_scalar/FederatedMatrixAdditionScalar.dml b/src/test/scripts/functions/federated/matrix_scalar/FederatedMatrixAdditionScalar.dml
index 9468448..6f1dac3 100644
--- a/src/test/scripts/functions/federated/matrix_scalar/FederatedMatrixAdditionScalar.dml
+++ b/src/test/scripts/functions/federated/matrix_scalar/FederatedMatrixAdditionScalar.dml
@@ -25,4 +25,5 @@
M = federated(addresses=list($in), ranges=list(list(0, 0), list($rows, $cols)))
S = $scalar
R = M + S
+
write(R, $out)