You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/17 22:19:25 UTC
[systemds] branch master updated: [SYSTEMDS-2730] Modified fed removeEmpty
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 2a8cb78 [SYSTEMDS-2730] Modified fed removeEmpty
2a8cb78 is described below
commit 2a8cb78827daed00fe016f6af22ab24f154be40c
Author: Olga <ov...@gmail.com>
AuthorDate: Tue Nov 17 21:21:55 2020 +0100
[SYSTEMDS-2730] Modified fed removeEmpty
This commit changes the federated removeEmpty command to, among
other things, improve the performance of the split function.
Closes #1109
---
scripts/builtin/split.dml | 13 +-
.../controlprogram/federated/FederationMap.java | 204 +++++++++---------
.../fed/ParameterizedBuiltinFEDInstruction.java | 234 +++++++++------------
.../primitives/FederatedRemoveEmptyTest.java | 10 +-
.../federated/primitives/FederatedSplitTest.java | 8 +-
.../functions/federated/FederatedSplitTest.dml | 3 +-
.../federated/FederatedSplitTestReference.dml | 2 +-
7 files changed, 217 insertions(+), 257 deletions(-)
diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
index 5e6f1c5..c5c1066 100644
--- a/scripts/builtin/split.dml
+++ b/scripts/builtin/split.dml
@@ -53,12 +53,13 @@ m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean con
}
# sampled train/test splits
else {
+ # create random select vector according to f and then
+ # extract tuples via permutation (selection) matrix multiply
+ # or directly via removeEmpty by selection vector
I = rand(rows=nrow(X), cols=1, seed=seed) <= f;
- P1 = removeEmpty(target=diag(I), margin="rows", select=I);
- P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
- Xtrain = P1 %*% X;
- Ytrain = P1 %*% Y;
- Xtest = P2 %*% X;
- Ytest = P2 %*% Y;
+ Xtrain = removeEmpty(target=X, margin="rows", select=I);
+ Ytrain = removeEmpty(target=Y, margin="rows", select=I);
+ Xtest = removeEmpty(target=X, margin="rows", select=(I==0));
+ Ytest = removeEmpty(target=Y, margin="rows", select=(I==0));
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 037ce8c..2ce3cb7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -43,12 +43,11 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
-public class FederationMap
-{
+public class FederationMap {
public enum FType {
- ROW, //row partitioned, groups of rows
- COL, //column partitioned, groups of columns
- FULL, // Meaning both Row and Column indicating a single federated location and a full matrix
+ ROW, // row partitioned, groups of rows
+ COL, // column partitioned, groups of columns
+ FULL, // Meaning both Row and Column indicating a single federated location and a full matrix
OTHER;
public boolean isRowPartitioned() {
@@ -56,11 +55,11 @@ public class FederationMap
}
public boolean isColPartitioned() {
- return this == ROW || this == FULL;
+ return this == COL || this == FULL;
}
- public boolean isType(FType t){
- switch (t) {
+ public boolean isType(FType t) {
+ switch(t) {
case ROW:
return isRowPartitioned();
case COL:
@@ -72,161 +71,159 @@ public class FederationMap
}
}
}
-
+
private long _ID = -1;
private final Map<FederatedRange, FederatedData> _fedMap;
private FType _type;
-
+
public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
this(-1, fedMap);
}
-
+
public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
this(ID, fedMap, FType.OTHER);
}
-
+
public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
_ID = ID;
_fedMap = fedMap;
_type = type;
}
-
+
public long getID() {
return _ID;
}
-
+
public FType getType() {
return _type;
}
-
+
public boolean isInitialized() {
return _ID >= 0;
}
-
+
public void setType(FType type) {
_type = type;
}
-
+
public int getSize() {
return _fedMap.size();
}
-
+
public FederatedRange[] getFederatedRanges() {
return _fedMap.keySet().toArray(new FederatedRange[0]);
}
- public Map<FederatedRange, FederatedData> getFedMapping(){
+ public Map<FederatedRange, FederatedData> getFedMapping() {
return _fedMap;
}
-
+
public FederatedRequest broadcast(CacheableData<?> data) {
- //prepare single request for all federated data
+ // prepare single request for all federated data
long id = FederationUtils.getNextFedDataID();
CacheBlock cb = data.acquireReadAndRelease();
return new FederatedRequest(RequestType.PUT_VAR, id, cb);
}
-
+
public FederatedRequest broadcast(ScalarObject scalar) {
- //prepare single request for all federated data
+ // prepare single request for all federated data
long id = FederationUtils.getNextFedDataID();
return new FederatedRequest(RequestType.PUT_VAR, id, scalar);
}
-
+
/**
- * Creates separate slices of an input data object according
- * to the index ranges of federated data. Theses slices are then
- * wrapped in separate federated requests for broadcasting.
+ * Creates separate slices of an input data object according to the index ranges of federated data. Theses slices
+ * are then wrapped in separate federated requests for broadcasting.
*
- * @param data input data object (matrix, tensor, frame)
- * @param transposed false: slice according to federated data,
- * true: slice according to transposed federated data
+ * @param data input data object (matrix, tensor, frame)
+ * @param transposed false: slice according to federated data, true: slice according to transposed federated data
* @return array of federated requests corresponding to federated data
*/
public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transposed) {
- //prepare broadcast id and pin input
+ // prepare broadcast id and pin input
long id = FederationUtils.getNextFedDataID();
CacheBlock cb = data.acquireReadAndRelease();
-
- //prepare indexing ranges
+
+ // prepare indexing ranges
int[][] ix = new int[_fedMap.size()][];
int pos = 0;
for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
- int rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
- int ru = transposed ? cb.getNumRows()-1 : e.getKey().getEndDimsInt()[0]-1;
- int cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
- int cu = transposed ? e.getKey().getEndDimsInt()[0]-1 : cb.getNumColumns()-1;
+ int rl, ru, cl, cu;
+ // TODO Handle different cases than ROW aligned Matrices.
+ rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
+ ru = transposed ? cb.getNumRows() - 1 : e.getKey().getEndDimsInt()[0] - 1;
+ cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
+ cu = transposed ? e.getKey().getEndDimsInt()[0] - 1 : cb.getNumColumns() - 1;
ix[pos++] = new int[] {rl, ru, cl, cu};
}
-
- //multi-threaded block slicing and federation request creation
+
+ // multi-threaded block slicing and federation request creation
FederatedRequest[] ret = new FederatedRequest[ix.length];
- Arrays.parallelSetAll(ret, i ->
- new FederatedRequest(RequestType.PUT_VAR, id,
- cb.slice(ix[i][0], ix[i][1], ix[i][2], ix[i][3], new MatrixBlock())));
+ Arrays.parallelSetAll(ret,
+ i -> new FederatedRequest(RequestType.PUT_VAR, id,
+ cb.slice(ix[i][0], ix[i][1], ix[i][2], ix[i][3], new MatrixBlock())));
return ret;
}
-
+
public boolean isAligned(FederationMap that, boolean transposed) {
- //determines if the two federated data are aligned row/column partitions
- //at the same federated site (which allows for purely federated operation)
+ // determines if the two federated data are aligned row/column partitions
+ // at the same federated site (which allows for purely federated operation)
boolean ret = true;
for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
- FederatedRange range = !transposed ? e.getKey() :
- new FederatedRange(e.getKey()).transpose();
+ FederatedRange range = !transposed ? e.getKey() : new FederatedRange(e.getKey()).transpose();
FederatedData dat2 = that._fedMap.get(range);
ret &= e.getValue().equalAddress(dat2);
}
return ret;
}
-
+
public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
return execute(tid, false, fr);
}
-
+
public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest... fr) {
return execute(tid, wait, null, fr);
}
-
+
public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] frSlices, FederatedRequest... fr) {
return execute(tid, false, frSlices, fr);
}
-
+
@SuppressWarnings("unchecked")
- public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices, FederatedRequest... fr) {
+ public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices,
+ FederatedRequest... fr) {
// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
setThreadID(tid, frSlices, fr);
- List<Future<FederatedResponse>> ret = new ArrayList<>();
+ List<Future<FederatedResponse>> ret = new ArrayList<>();
int pos = 0;
for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
- ret.add(e.getValue().executeFederatedOperation(
- (frSlices!=null) ? addAll(frSlices[pos++], fr) : fr));
-
- // prepare results (future federated responses), with optional wait to ensure the
+ ret.add(e.getValue().executeFederatedOperation((frSlices != null) ? addAll(frSlices[pos++], fr) : fr));
+
+ // prepare results (future federated responses), with optional wait to ensure the
// order of requests without data dependencies (e.g., cleanup RPCs)
- if( wait )
+ if(wait)
FederationUtils.waitFor(ret);
return ret.toArray(new Future[0]);
}
-
+
public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
- if( !isInitialized() )
+ if(!isInitialized())
throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
-
+
List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
for(Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
- readResponses.add(new ImmutablePair<>(e.getKey(),
- e.getValue().executeFederatedOperation(request)));
+ readResponses.add(new ImmutablePair<>(e.getKey(), e.getValue().executeFederatedOperation(request)));
return readResponses;
}
-
+
public FederatedRequest cleanup(long tid, long... id) {
FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
VariableCPInstruction.prepareRemoveInstruction(id).toString());
request.setTID(tid);
return request;
}
-
+
public void execCleanup(long tid, long... id) {
FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
VariableCPInstruction.prepareRemoveInstruction(id).toString());
@@ -234,16 +231,17 @@ public class FederationMap
List<Future<FederatedResponse>> tmp = new ArrayList<>();
for(FederatedData fd : _fedMap.values())
tmp.add(fd.executeFederatedOperation(request));
- //wait to avoid interference w/ following requests
+ // wait to avoid interference w/ following requests
FederationUtils.waitFor(tmp);
}
-
+
private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
FederatedRequest[] ret = new FederatedRequest[b.length + 1];
- ret[0] = a; System.arraycopy(b, 0, ret, 1, b.length);
+ ret[0] = a;
+ System.arraycopy(b, 0, ret, 1, b.length);
return ret;
}
-
+
public FederationMap identCopy(long tid, long id) {
Future<FederatedResponse>[] copyInstr = execute(tid,
new FederatedRequest(RequestType.EXEC_INST, _ID,
@@ -262,25 +260,25 @@ public class FederationMap
copyFederationMap._type = _type;
return copyFederationMap;
}
-
+
public FederationMap copyWithNewID() {
return copyWithNewID(FederationUtils.getNextFedDataID());
}
-
+
public FederationMap copyWithNewID(long id) {
Map<FederatedRange, FederatedData> map = new TreeMap<>();
- //TODO handling of file path, but no danger as never written
- for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) {
+ // TODO handling of file path, but no danger as never written
+ for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
if(e.getKey().getSize() != 0)
map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
}
return new FederationMap(id, map, _type);
}
-
+
public FederationMap copyWithNewID(long id, long clen) {
Map<FederatedRange, FederatedData> map = new TreeMap<>();
- //TODO handling of file path, but no danger as never written
- for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
+ // TODO handling of file path, but no danger as never written
+ for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
return new FederationMap(id, map, _type);
}
@@ -295,24 +293,28 @@ public class FederationMap
public FederationMap transpose() {
Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
_fedMap.clear();
- for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
+ for(Entry<FederatedRange, FederatedData> e : tmp.entrySet()) {
_fedMap.put(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID));
}
- //derive output type
+ // derive output type
switch(_type) {
- case FULL: _type = FType.FULL; break;
- case ROW: _type = FType.COL; break;
- case COL: _type = FType.ROW; break;
- default: _type = FType.OTHER;
+ case FULL:
+ _type = FType.FULL;
+ break;
+ case ROW:
+ _type = FType.COL;
+ break;
+ case COL:
+ _type = FType.ROW;
+ break;
+ default:
+ _type = FType.OTHER;
}
return this;
}
-
public long getMaxIndexInRange(int dim) {
- return _fedMap.keySet().stream()
- .mapToLong(range -> range.getEndDims()[dim]).max()
- .orElse(-1L);
+ return _fedMap.keySet().stream().mapToLong(range -> range.getEndDims()[dim]).max().orElse(-1L);
}
/**
@@ -352,27 +354,27 @@ public class FederationMap
fedMapCopy._ID = newVarID;
return fedMapCopy;
}
-
+
public FederationMap filter(IndexRange ixrange) {
- FederationMap ret = this.clone(); //same ID
-
+ FederationMap ret = this.clone(); // same ID
+
Iterator<Entry<FederatedRange, FederatedData>> iter = ret._fedMap.entrySet().iterator();
- while( iter.hasNext() ) {
+ while(iter.hasNext()) {
Entry<FederatedRange, FederatedData> e = iter.next();
FederatedRange range = e.getKey();
- long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
- cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
- boolean overlap = ((ixrange.colStart <= ce) && (ixrange.colEnd >= cs)
- && (ixrange.rowStart <= re) && (ixrange.rowEnd >= rs));
- if( !overlap )
+ long rs = range.getBeginDims()[0], re = range.getEndDims()[0], cs = range.getBeginDims()[1],
+ ce = range.getEndDims()[1];
+ boolean overlap = ((ixrange.colStart <= ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart <= re) &&
+ (ixrange.rowEnd >= rs));
+ if(!overlap)
iter.remove();
}
return ret;
}
-
+
private static void setThreadID(long tid, FederatedRequest[]... frsets) {
- for( FederatedRequest[] frset : frsets )
- if( frset != null )
+ for(FederatedRequest[] frset : frsets)
+ if(frset != null)
Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
}
@@ -399,14 +401,14 @@ public class FederationMap
}
@Override
- public String toString(){
+ public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("Fed Map: " + _type);
sb.append("\t ID:" + _ID);
- sb.append("\n"+ _fedMap);
+ sb.append("\n" + _fedMap);
return sb.toString();
}
-
+
@Override
public FederationMap clone() {
return copyWithNewID(getID());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c50671e..6588909 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,7 +19,6 @@
package org.apache.sysds.runtime.instructions.fed;
-import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -136,9 +135,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
out.getDataCharacteristics().set(mo.getDataCharacteristics());
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
- else if(opcode.equals("rmempty")) {
+ else if(opcode.equals("rmempty"))
rmempty(ec);
- }
else if(opcode.equalsIgnoreCase("transformdecode"))
transformDecode(ec);
else if(opcode.equalsIgnoreCase("transformapply"))
@@ -149,32 +147,33 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
private void rmempty(ExecutionContext ec) {
+ String margin = params.get("margin");
+ if( !(margin.equals("rows") || margin.equals("cols")) )
+ throw new DMLRuntimeException("Unspupported margin identifier '"+margin+"'.");
+
MatrixObject mo = (MatrixObject) getTarget(ec);
+ MatrixObject select = params.containsKey("select") ? ec.getMatrixObject(params.get("select")) : null;
MatrixObject out = ec.getMatrixObject(output);
- Map<FederatedRange, int[]> dcs;
- if((instString.contains("margin=rows") && mo.isFederated(FederationMap.FType.ROW)) ||
- (instString.contains("margin=cols") && mo.isFederated(FederationMap.FType.COL))) {
- FederatedRequest fr1 = FederationUtils.callInstruction(instString,
- output,
- new CPOperand[] {getTargetOperand()},
- new long[] {mo.getFedMapping().getID()});
- mo.getFedMapping().execute(getTID(), true, fr1);
- out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
- // new ranges
- dcs = new HashMap<>();
- out.getFedMapping().forEachParallel((range, data) -> {
+ boolean marginRow = params.get("margin").equals("rows");
+ boolean k = ((marginRow && mo.getFedMapping().getType().isColPartitioned()) ||
+ (!marginRow && mo.getFedMapping().getType().isRowPartitioned()));
+
+ MatrixBlock s = new MatrixBlock();
+ if(select == null && k) {
+ List<MatrixBlock> colSums = new ArrayList<>();
+ mo.getFedMapping().forEachParallel((range, data) -> {
try {
FederatedResponse response = data
.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new GetDataCharacteristics(data.getVarID())))
+ new GetVector(data.getVarID(), margin.equals("rows"))))
.get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
- int[] subRangeCharacteristics = (int[]) response.getData()[0];
- synchronized(dcs) {
- dcs.put(range, subRangeCharacteristics);
+ MatrixBlock vector = (MatrixBlock) response.getData()[0];
+ synchronized(colSums) {
+ colSums.add(vector);
}
}
catch(Exception e) {
@@ -182,53 +181,75 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
return null;
});
+ // find empty in matrix
+ BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+ BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
+ s = colSums.get(0);
+ for(int i = 1; i < colSums.size(); i++)
+ s = s.binaryOperationsInPlace(plus, colSums.get(i));
+ s = s.binaryOperationsInPlace(greater, new MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
+ select = ExecutionContext.createMatrixObject(s);
+
+ long varID = FederationUtils.getNextFedDataID();
+ ec.setVariable(String.valueOf(varID), select);
+ params.put("select", String.valueOf(varID));
+ // construct new string
+ String[] oldString = InstructionUtils.getInstructionParts(instString);
+ String[] newString = new String[oldString.length+1];
+ newString[2] = "select="+varID;
+ System.arraycopy(oldString, 0, newString, 0,2);
+ System.arraycopy(oldString,2, newString, 3, newString.length-3);
+ instString = instString.replace(InstructionUtils.concatOperands(oldString), InstructionUtils.concatOperands(newString));
}
- else {
- Map.Entry<FederationMap, Map<FederatedRange, int[]>> entry = rmemptyC(ec, mo);
- out.setFedMapping(entry.getKey());
- dcs = entry.getValue();
- }
- out.getDataCharacteristics().set(mo.getDataCharacteristics());
- for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) {
- int[] newRange = dcs.get(out.getFedMapping().getFederatedRanges()[i]);
-
- out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
- (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
- i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
-
- out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
- out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
-
- out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
- (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
- i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
- out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
- out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+ if (select == null) {
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[] {getTargetOperand()},
+ new long[] {mo.getFedMapping().getID()});
+ mo.getFedMapping().execute(getTID(), true, fr1);
+ out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
+ else if (!k) {
+ //construct commands: broadcast , fed rmempty, clean broadcast
+ FederatedRequest[] fr1 = mo.getFedMapping().broadcastSliced(select, !marginRow);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {getTargetOperand(), new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+ new long[] {mo.getFedMapping().getID(), fr1[0].getID()});
+ FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1[0].getID());
+
+ //execute federated operations and set output
+ mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+ } else {
+ //construct commands: broadcast , fed rmempty, clean broadcast
+ FederatedRequest fr1 = mo.getFedMapping().broadcast(select);
+ FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+ output,
+ new CPOperand[] {getTargetOperand(), new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+ new long[] {mo.getFedMapping().getID(), fr1.getID()});
+ FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
- out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
- out.getFedMapping().getMaxIndexInRange(1),
- (int) mo.getBlocksize());
- }
-
- private Map.Entry<FederationMap, Map<FederatedRange, int[]>> rmemptyC(ExecutionContext ec, MatrixObject mo) {
- boolean marginRow = instString.contains("margin=rows");
+ //execute federated operations and set output
+ mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+ out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+ }
- // find empty in ranges
- List<MatrixBlock> colSums = new ArrayList<>();
- mo.getFedMapping().forEachParallel((range, data) -> {
+ // new ranges
+ Map<FederatedRange, int[]> dcs = new HashMap<>();
+ Map<FederatedRange, int[]> finalDcs1 = dcs;
+ out.getFedMapping().forEachParallel((range, data) -> {
try {
FederatedResponse response = data
.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new GetVector(data.getVarID(), marginRow)))
+ new GetDataCharacteristics(data.getVarID())))
.get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
- MatrixBlock vector = (MatrixBlock) response.getData()[0];
- synchronized(colSums) {
- colSums.add(vector);
+ int[] subRangeCharacteristics = (int[]) response.getData()[0];
+ synchronized(finalDcs1) {
+ finalDcs1.put(range, subRangeCharacteristics);
}
}
catch(Exception e) {
@@ -236,46 +257,28 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
return null;
});
+ dcs = finalDcs1;
+ out.getDataCharacteristics().set(mo.getDataCharacteristics());
+ for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) {
+ int[] newRange = dcs.get(out.getFedMapping().getFederatedRanges()[i]);
- // find empty in matrix
- BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
- BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
- MatrixBlock tmp1 = colSums.get(0);
- for(int i = 1; i < colSums.size(); i++)
- tmp1 = tmp1.binaryOperationsInPlace(plus, colSums.get(i));
- tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+ out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+ (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+ i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
- // remove empty from matrix
- Map<FederatedRange, int[]> dcs = new HashMap<>();
- long varID = FederationUtils.getNextFedDataID();
- MatrixBlock finalTmp = new MatrixBlock(tmp1);
- FederationMap resMapping;
- if(tmp1.sum() == (marginRow ? tmp1.getNumColumns() : tmp1.getNumRows())) {
- resMapping = mo.getFedMapping();
- }
- else {
- resMapping = mo.getFedMapping().mapParallel(varID, (range, data) -> {
- try {
- FederatedResponse response = data
- .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID, finalTmp,
- params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null,
- Boolean.parseBoolean(params.get("empty.return").toLowerCase()), marginRow)))
- .get();
- if(!response.isSuccessful())
- response.throwExceptionFromResponse();
- int[] subRangeCharacteristics = (int[]) response.getData()[0];
- synchronized(dcs) {
- dcs.put(range, subRangeCharacteristics);
- }
- }
- catch(Exception e) {
- throw new DMLRuntimeException(e);
- }
- return null;
- });
+ out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+ out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+ out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+ (out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+ i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+ out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+ out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
}
- return new AbstractMap.SimpleEntry<>(resMapping, dcs);
+
+ out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+ out.getFedMapping().getMaxIndexInRange(1), (int) mo.getBlocksize());
}
private void transformDecode(ExecutionContext ec) {
@@ -506,52 +509,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
@Override
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
- return new FederatedResponse(ResponseType.SUCCESS, new int[] {mb.getNumRows(), mb.getNumColumns()});
- }
- }
-
- private static class RemoveEmpty extends FederatedUDF {
-
- private static final long serialVersionUID = 12341521331L;
- private final MatrixBlock _vector;
- private final long _outputID;
- private MatrixBlock _select;
- private boolean _emptyReturn;
- private final boolean _marginRow;
-
- public RemoveEmpty(long varID, long outputID, MatrixBlock vector, MatrixBlock select, boolean emptyReturn,
- boolean marginRow) {
- super(new long[] {varID});
- _vector = vector;
- _outputID = outputID;
- _select = select;
- _emptyReturn = emptyReturn;
- _marginRow = marginRow;
- }
-
- @Override
- public FederatedResponse execute(ExecutionContext ec, Data... data) {
- MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
-
- BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
- BinaryOperator minus = InstructionUtils.parseBinaryOperator("-");
-
- mb = mb.binaryOperationsInPlace(plus, new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1.0));
- for(int i = 0; i < mb.getNumRows(); i++)
- for(int j = 0; j < mb.getNumColumns(); j++)
- if(_marginRow)
- mb.setValue(i, j, _vector.getValue(i, 0) * mb.getValue(i, j));
- else
- mb.setValue(i, j, _vector.getValue(0, j) * mb.getValue(i, j));
-
- MatrixBlock res = mb.removeEmptyOperations(new MatrixBlock(), _marginRow, _emptyReturn, _select);
- res = res.binaryOperationsInPlace(minus, new MatrixBlock(res.getNumRows(), res.getNumColumns(), 1.0));
-
- MatrixObject mout = ExecutionContext.createMatrixObject(res);
- ec.setVariable(String.valueOf(_outputID), mout);
-
- return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
- new int[] {res.getNumRows(), res.getNumColumns()});
+ int r = mb.getDenseBlockValues() != null ? mb.getNumRows() : 0;
+ int c = mb.getDenseBlockValues() != null ? mb.getNumColumns(): 0;
+ return new FederatedResponse(ResponseType.SUCCESS, new int[] {r, c});
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index a629270..10a6711 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -55,7 +55,10 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, false}});
+ return Arrays.asList(new Object[][] {
+ {20, 12, true},
+ {20, 12, false}
+ });
}
@Override
@@ -94,11 +97,6 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
for(int k : new int[] {1, 2, 3}) {
Arrays.fill(X3[k], 0);
- if(!rowPartitioned) {
- Arrays.fill(X1[k], 0);
- Arrays.fill(X2[k], 0);
- Arrays.fill(X4[k], 0);
- }
}
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index 3e640c0..afd2ffe 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -54,7 +54,9 @@ public class FederatedSplitTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 11, "FALSE"}});
+ return Arrays.asList(new Object[][] {
+ // {152, 12, "TRUE"},
+ {132, 11, "FALSE"}});
}
@Override
@@ -125,9 +127,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
if(cont.equals("TRUE"))
Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
else {
- Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- // TODO add federated diag operator.
- // Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+ Assert.assertTrue(heavyHittersContainsString("fed_rmempty"));
}
TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/scripts/functions/federated/FederatedSplitTest.dml b/src/test/scripts/functions/federated/FederatedSplitTest.dml
index 44c59a9..e1fc647 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTest.dml
@@ -24,7 +24,6 @@ X = federated(addresses=list($X1, $X2),
Y = federated(addresses=list($Y1, $Y2),
ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
-
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y,f=0.95, cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
write(Xte, $Z)
print(toString(Xte))
diff --git a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
index 4db8e1f..962dd84 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
@@ -21,6 +21,6 @@
X = rbind(read($X1), read($X2))
Y = rbind(read($Y1), read($Y2))
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y, f=0.95 ,cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
write(Xte, $Z)
print(toString(Xte))