You are viewing a plain-text version of this content. The canonical link to it (originally a hyperlink labeled "here") was not preserved in this copy.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/18 20:17:33 UTC
[systemds] branch master updated: [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 82536c1 [SYSTEMDS-3101] Fix federated spoof instruction (federated output)
82536c1 is described below
commit 82536c1841b546db4f519086d2d7a6cba011603c
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Sep 18 22:16:43 2021 +0200
[SYSTEMDS-3101] Fix federated spoof instruction (federated output)
Closes #1380.
Other cleanups:
Closes #1336.
Closes #1365.
---
.../instructions/fed/SpoofFEDInstruction.java | 493 ++++++++++++---------
.../codegen/FederatedCodegenMultipleFedMOTest.java | 6 +-
.../codegen/FederatedOuterProductTmplTest.java | 8 +-
.../codegen/FederatedRowwiseTmplTest.java | 2 +-
.../pipelines/BuiltinTopkEvaluateTest.java | 1 -
5 files changed, 289 insertions(+), 221 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index ecf310c..331ecfc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.Future;
+import java.util.stream.IntStream;
public class SpoofFEDInstruction extends FEDInstruction
{
@@ -82,37 +83,38 @@ public class SpoofFEDInstruction extends FEDInstruction
@Override
public void processInstruction(ExecutionContext ec) {
+ FederationMap fedMap = null;
+ for(CPOperand cpo : _inputs) { // searching for the first federated matrix to obtain the federation map
+ Data tmpData = ec.getVariable(cpo);
+ if(tmpData instanceof MatrixObject && ((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
+ fedMap = ((MatrixObject)tmpData).getFedMapping();
+ break;
+ }
+ }
+
Class<?> scla = _op.getClass().getSuperclass();
SpoofFEDType spoofType = null;
if(scla == SpoofCellwise.class)
- spoofType = new SpoofFEDCellwise(_op, _output);
+ spoofType = new SpoofFEDCellwise(_op, _output, fedMap.getType());
else if(scla == SpoofRowwise.class)
- spoofType = new SpoofFEDRowwise(_op, _output);
+ spoofType = new SpoofFEDRowwise(_op, _output, fedMap.getType());
else if(scla == SpoofMultiAggregate.class)
- spoofType = new SpoofFEDMultiAgg(_op, _output);
+ spoofType = new SpoofFEDMultiAgg(_op, _output, fedMap.getType());
else if(scla == SpoofOuterProduct.class)
- spoofType = new SpoofFEDOuterProduct(_op, _output);
+ spoofType = new SpoofFEDOuterProduct(_op, _output, fedMap.getType(), _inputs);
else
throw new DMLRuntimeException("Federated code generation only supported" +
" for cellwise, rowwise, multiaggregate, and outerproduct templates.");
+ processRequest(ec, fedMap, spoofType);
+ }
- FederationMap fedMap = null;
- long id = 0;
- for(CPOperand cpo : _inputs) { // searching for the first federated matrix to obtain the federation map
- Data tmpData = ec.getVariable(cpo);
- if(tmpData instanceof MatrixObject && ((MatrixObject)tmpData).isFederatedExcept(FType.BROADCAST)) {
- fedMap = ((MatrixObject)tmpData).getFedMapping();
- id = ((MatrixObject)tmpData).getUniqueID();
- break;
- }
- }
-
+ private void processRequest(ExecutionContext ec, FederationMap fedMap, SpoofFEDType spoofType) {
ArrayList<FederatedRequest> frBroadcast = new ArrayList<>();
ArrayList<FederatedRequest[]> frBroadcastSliced = new ArrayList<>();
long[] frIds = new long[_inputs.length];
int index = 0;
-
+
for(CPOperand cpo : _inputs) {
Data tmpData = ec.getVariable(cpo);
if(tmpData instanceof MatrixObject) {
@@ -121,7 +123,7 @@ public class SpoofFEDInstruction extends FEDInstruction
frIds[index++] = mo.getFedMapping().getID();
}
else if(spoofType.needsBroadcastSliced(fedMap, mo.getNumRows(), mo.getNumColumns(), index)) {
- FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap, id);
+ FederatedRequest[] tmpFr = spoofType.broadcastSliced(mo, fedMap);
frIds[index++] = tmpFr[0].getID();
frBroadcastSliced.add(tmpFr);
}
@@ -144,48 +146,71 @@ public class SpoofFEDInstruction extends FEDInstruction
FederatedRequest frCompute = FederationUtils.callInstruction(instString, _output, _inputs, frIds);
- // get partial results from federated workers
- FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+ FederatedRequest frGet = null;
+ FederatedRequest frCleanup = null;
+ if(!spoofType.isFedOutput()) {
+ // get partial results from federated workers
+ frGet = new FederatedRequest(RequestType.GET_VAR, frCompute.getID());
+ // cleanup the federated request of callInstruction
+ frCleanup = fedMap.cleanup(getTID(), frCompute.getID());
+ }
- ArrayList<FederatedRequest> frCleanup = new ArrayList<>();
- frCleanup.add(fedMap.cleanup(getTID(), frCompute.getID()));
- for(FederatedRequest[] fr : frBroadcastSliced)
- frCleanup.add(fedMap.cleanup(getTID(), fr[0].getID()));
+ FederatedRequest[] frAll;
+ if(frGet == null) // no get request if output is kept federated
+ frAll = ArrayUtils.addAll(
+ frBroadcast.toArray(new FederatedRequest[0]), frCompute);
+ else
+ frAll = ArrayUtils.addAll(
+ frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet, frCleanup);
- FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
- frBroadcast.toArray(new FederatedRequest[0]), frCompute, frGet),
- frCleanup.toArray(new FederatedRequest[0]));
Future<FederatedResponse>[] response = fedMap.executeMultipleSlices(
getTID(), true, frBroadcastSliced.toArray(new FederatedRequest[0][]), frAll);
// setting the output with respect to the different aggregation types
// of the different spoof templates
- spoofType.setOutput(ec, response, fedMap);
+ spoofType.setOutput(ec, response, fedMap, frCompute.getID());
}
+ // abstract class to differentiate between the different spoof templates
private static abstract class SpoofFEDType {
CPOperand _output;
+ FType _fedType;
- protected SpoofFEDType(CPOperand out) {
+ protected SpoofFEDType(CPOperand out, FType fedType) {
_output = out;
+ _fedType = fedType;
}
-
- protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap, long id) {
+
+ /**
+ * performs the sliced broadcast of the given matrix object
+ *
+ * @param mo the matrix object to broadcast sliced
+ * @param fedMap the federated mapping
+ * @return FederatedRequest[] the resulting federated request array of the broadcast
+ */
+ protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
return fedMap.broadcastSliced(mo, false);
}
+ /**
+ * determine if a specific matrix object needs to be broadcast sliced
+ *
+ * @param fedMap the federated mapping
+ * @param rowNum the number of rows of the matrix object
+ * @param colNum the number of columns of the matrix object
+ * @param inputIndex the index of the matrix inside the instruction inputs
+ * @return boolean indicates if the matrix needs to be broadcast sliced
+ */
protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
- FType fedType = fedMap.getType();
-
//TODO fix check by num rows/cols
boolean retVal = (rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
- if(fedType == FType.ROW)
- retVal |= (rowNum == fedMap.getMaxIndexInRange(0)
- && (colNum == 1 || colNum == fedMap.getSize() || fedMap.getMaxIndexInRange(1) == 1));
- else if(fedType == FType.COL)
+ if(_fedType == FType.ROW)
+ retVal |= (rowNum == fedMap.getMaxIndexInRange(0)
+ && (colNum == 1 || fedMap.getMaxIndexInRange(1) == 1));
+ else if(_fedType == FType.COL)
retVal |= (colNum == fedMap.getMaxIndexInRange(1)
- && (rowNum == 1 || rowNum == fedMap.getSize() || fedMap.getMaxIndexInRange(0) == 1));
+ && (rowNum == 1 || fedMap.getMaxIndexInRange(0) == 1));
else {
throw new DMLRuntimeException("Only row partitioned or column" +
" partitioned federated input supported yet.");
@@ -193,236 +218,281 @@ public class SpoofFEDInstruction extends FEDInstruction
return retVal;
}
- protected abstract void setOutput(ExecutionContext ec,
- Future<FederatedResponse>[] response, FederationMap fedMap);
+ /**
+ * set the output by either calling setFedOutput to keep the output federated
+ * or calling aggResult to aggregate the partial results locally
+ */
+ protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap, long frComputeID) {
+ if(isFedOutput())
+ setFedOutput(ec, fedMap, frComputeID);
+ else
+ aggResult(ec, response, fedMap);
+ }
+
+ // determine if the output can be kept on the federated sites
+ protected abstract boolean isFedOutput();
+ // set the output by deriving new a federated mapping
+ protected abstract void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID);
+ // aggregate the partial results locally
+ protected abstract void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap);
}
+ // CELLWISE TEMPLATE
private static class SpoofFEDCellwise extends SpoofFEDType {
private final SpoofCellwise _op;
+ private final CellType _cellType;
- SpoofFEDCellwise(SpoofOperator op, CPOperand out) {
- super(out);
+ SpoofFEDCellwise(SpoofOperator op, CPOperand out, FType fedType) {
+ super(out, fedType);
_op = (SpoofCellwise)op;
+ _cellType = _op.getCellType();
}
- protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
- FType fedType = fedMap.getType();
- AggOp aggOp = ((SpoofCellwise)_op).getAggOp();
- CellType cellType = ((SpoofCellwise)_op).getCellType();
- if(cellType == CellType.FULL_AGG) { // full aggregation
- AggregateUnaryOperator aop = null;
- if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- else if(aggOp == AggOp.MIN)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamin");
- else if(aggOp == AggOp.MAX)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uamax");
- else
- throw new DMLRuntimeException("Aggregation operation not supported yet.");
- ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
- }
- else if(cellType == CellType.ROW_AGG) { // row aggregation
- if(fedType == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- }
- else if(fedType == FType.COL) {
- AggregateUnaryOperator aop = null;
- if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
- else if(aggOp == AggOp.MIN)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmin");
- else if(aggOp == AggOp.MAX)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uarmax");
- else
- throw new DMLRuntimeException("Aggregation operation not supported yet.");
- ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else {
- throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
- }
- }
- else if(cellType == CellType.COL_AGG) { // col aggregation
- if(fedType == FType.ROW) {
- AggregateUnaryOperator aop = null;
- if(aggOp == AggOp.SUM || aggOp == AggOp.SUM_SQ)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
- else if(aggOp == AggOp.MIN)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin");
- else if(aggOp == AggOp.MAX)
- aop = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax");
- else
- throw new DMLRuntimeException("Aggregation operation not supported yet.");
- ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(fedType == FType.COL) {
- // cbind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
- }
- else {
- throw new DMLRuntimeException("Aggregation type for federated spoof instructions not supported yet.");
- }
+ protected boolean isFedOutput() {
+ boolean retVal = false;
+ retVal |= (_cellType == CellType.ROW_AGG && _fedType == FType.ROW);
+ retVal |= (_cellType == CellType.COL_AGG && _fedType == FType.COL);
+ retVal |= (_cellType == CellType.NO_AGG);
+ return retVal;
+ }
+
+ protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+ // derive output federated mapping
+ MatrixObject out = ec.getMatrixObject(_output);
+ FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frComputeID));
+ out.setFedMapping(newFedMap);
+ }
+
+ private FederationMap modifyFedRanges(FederationMap fedMap) {
+ if(_cellType == CellType.ROW_AGG || _cellType == CellType.COL_AGG) {
+ int dim = (_cellType == CellType.COL_AGG ? 0 : 1);
+ // crop federation map to a vector
+ IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+ fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+ fedMap.getFederatedRanges()[i].setEndDim(dim, 1);
+ });
}
- else if(cellType == CellType.NO_AGG) { // no aggregation
- if(fedType == FType.ROW) //rbind
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- else if(fedType == FType.COL) //cbind
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
- else
- throw new DMLRuntimeException("Only row partitioned or column" +
- " partitioned federated matrices supported yet.");
+ return fedMap;
+ }
+
+ protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap) {
+ AggOp aggOp = _op.getAggOp();
+
+ // build up the instruction for aggregation
+ // (uak+/uamin/uamax/uark+/uarmin/uarmax/uack+/uacmin/uacmax)
+ String aggInst = "ua";
+ switch(_cellType) {
+ case FULL_AGG: break;
+ case ROW_AGG: aggInst += "r"; break;
+ case COL_AGG: aggInst += "c"; break;
+ case NO_AGG:
+ default:
+ throw new DMLRuntimeException("Aggregation type not supported yet.");
}
- else {
- throw new DMLRuntimeException("Aggregation type not supported yet.");
+
+ switch(aggOp) {
+ case SUM:
+ case SUM_SQ: aggInst += "k+"; break;
+ case MIN: aggInst += "min"; break;
+ case MAX: aggInst += "max"; break;
+ default:
+ throw new DMLRuntimeException("Aggregation operation not supported yet.");
}
+
+ AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+ if(_cellType == CellType.FULL_AGG)
+ ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+ else
+ ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
}
}
+ // ROWWISE TEMPLATE
private static class SpoofFEDRowwise extends SpoofFEDType {
private final SpoofRowwise _op;
+ private final RowType _rowType;
- SpoofFEDRowwise(SpoofOperator op, CPOperand out) {
- super(out);
+ SpoofFEDRowwise(SpoofOperator op, CPOperand out, FType fedType) {
+ super(out, fedType);
_op = (SpoofRowwise)op;
+ _rowType = _op.getRowType();
+ }
+
+ protected boolean isFedOutput() {
+ boolean retVal = false;
+ retVal |= (_rowType == RowType.NO_AGG);
+ retVal |= (_rowType == RowType.NO_AGG_B1);
+ retVal |= (_rowType == RowType.NO_AGG_CONST);
+ retVal &= (_fedType == FType.ROW);
+ return retVal;
}
- protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
- RowType rowType = ((SpoofRowwise)_op).getRowType();
- if(rowType == RowType.FULL_AGG) { // full aggregation
- // aggregate partial results from federated responses as sum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+ // derive output federated mapping
+ MatrixObject out = ec.getMatrixObject(_output);
+ FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frComputeID), out.getNumColumns());
+ out.setFedMapping(newFedMap);
+ }
+
+ private static FederationMap modifyFedRanges(FederationMap fedMap, long cols) {
+ IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+ fedMap.getFederatedRanges()[i].setBeginDim(1, 0);
+ fedMap.getFederatedRanges()[i].setEndDim(1, cols);
+ });
+ return fedMap;
+ }
+
+ protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap) {
+ if(_fedType != FType.ROW)
+ throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
+
+ // build up the instruction for aggregation (uak+/uark+/uack+)
+ String aggInst = "ua";
+ if(_rowType == RowType.FULL_AGG) // full aggregation
+ aggInst += "k+";
+ else if(_rowType == RowType.ROW_AGG) // row aggregation
+ aggInst += "rk+";
+ else if(_rowType.isColumnAgg()) // col aggregation
+ aggInst += "ck+";
+ else
+ throw new DMLRuntimeException("AggregationType not supported yet.");
+
+ // aggregate partial results from federated responses as sum/rowSum/colSum
+ AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(aggInst);
+ if(_rowType == RowType.FULL_AGG)
ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
- }
- else if(rowType == RowType.ROW_AGG) { // row aggregation
- // aggregate partial results from federated responses as rowSum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
- ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(rowType == RowType.COL_AGG
- || rowType == RowType.COL_AGG_T
- || rowType == RowType.COL_AGG_B1
- || rowType == RowType.COL_AGG_B1_T
- || rowType == RowType.COL_AGG_B1R
- || rowType == RowType.COL_AGG_CONST) { // col aggregation
- // aggregate partial results from federated responses as colSum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uack+");
+ else
ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(rowType == RowType.NO_AGG
- || rowType == RowType.NO_AGG_B1
- || rowType == RowType.NO_AGG_CONST) { // no aggregation
- if(fedMap.getType() == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- }
- else {
- throw new DMLRuntimeException("Only row partitioned federated matrices supported yet.");
- }
- }
- else {
- throw new DMLRuntimeException("AggregationType not supported yet.");
- }
}
}
+ // MULTIAGGREGATE TEMPLATE
private static class SpoofFEDMultiAgg extends SpoofFEDType {
private final SpoofMultiAggregate _op;
- SpoofFEDMultiAgg(SpoofOperator op, CPOperand out) {
- super(out);
+ SpoofFEDMultiAgg(SpoofOperator op, CPOperand out, FType fedType) {
+ super(out, fedType);
_op = (SpoofMultiAggregate)op;
}
- protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
- MatrixBlock[] partRes = FederationUtils.getResults(response);
- SpoofCellwise.AggOp[] aggOps = ((SpoofMultiAggregate)_op).getAggOps();
- for(int counter = 1; counter < partRes.length; counter++) {
- SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[counter]);
- }
- ec.setMatrixOutput(_output.getName(), partRes[0]);
+ protected boolean isFedOutput() {
+ return false;
}
- }
+ protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+ throw new DMLRuntimeException("SpoofFEDMultiAgg cannot create a federated output.");
+ }
+ protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap) {
+ MatrixBlock[] partRes = FederationUtils.getResults(response);
+ SpoofCellwise.AggOp[] aggOps = _op.getAggOps();
+ for(int counter = 1; counter < partRes.length; counter++) {
+ SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[counter]);
+ }
+ ec.setMatrixOutput(_output.getName(), partRes[0]);
+ }
+ }
+
+ // OUTER PRODUCT TEMPLATE
private static class SpoofFEDOuterProduct extends SpoofFEDType {
private final SpoofOuterProduct _op;
+ private final OutProdType _outProdType;
+ private CPOperand[] _inputs;
- SpoofFEDOuterProduct(SpoofOperator op, CPOperand out) {
- super(out);
+ SpoofFEDOuterProduct(SpoofOperator op, CPOperand out, FType fedType, CPOperand[] inputs) {
+ super(out, fedType);
_op = (SpoofOuterProduct)op;
+ _outProdType = _op.getOuterProdType();
+ _inputs = inputs;
+ }
+
+ protected FederatedRequest[] broadcastSliced(MatrixObject mo, FederationMap fedMap) {
+ return fedMap.broadcastSliced(mo, (_fedType == FType.COL));
}
protected boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum, int inputIndex) {
boolean retVal = false;
- FType fedType = fedMap.getType();
retVal |= (rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1));
- if(fedType == FType.ROW)
+ if(_fedType == FType.ROW)
retVal |= (rowNum == fedMap.getMaxIndexInRange(0)) && (inputIndex != 2); // input at index 2 is V
- else if(fedType == FType.COL)
+ else if(_fedType == FType.COL)
retVal |= (rowNum == fedMap.getMaxIndexInRange(1)) && (inputIndex != 1); // input at index 1 is U
else
throw new DMLRuntimeException("Only row partitioned or column" +
" partitioned federated input supported yet.");
-
+
return retVal;
}
- protected void setOutput(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap) {
- FType fedType = fedMap.getType();
- OutProdType outProdType = ((SpoofOuterProduct)_op).getOuterProdType();
- if(outProdType == OutProdType.LEFT_OUTER_PRODUCT) {
- if(fedType == FType.ROW) {
- // aggregate partial results from federated responses as elementwise sum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else if(fedType == FType.COL) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- }
- else {
- throw new DMLRuntimeException("Only row partitioned or column" +
- " partitioned federated matrices supported yet.");
- }
+ protected boolean isFedOutput() {
+ boolean retVal = false;
+ retVal |= (_outProdType == OutProdType.LEFT_OUTER_PRODUCT && _fedType == FType.COL);
+ retVal |= (_outProdType == OutProdType.RIGHT_OUTER_PRODUCT && _fedType == FType.ROW);
+ retVal |= (_outProdType == OutProdType.CELLWISE_OUTER_PRODUCT);
+ return retVal;
+ }
+
+ protected void setFedOutput(ExecutionContext ec, FederationMap fedMap, long frComputeID) {
+ FederationMap newFedMap = fedMap.copyWithNewID(frComputeID);
+ long[] outDims = new long[2];
+
+ // find the resulting output dimensions
+ MatrixObject X = ec.getMatrixObject(_inputs[0]);
+ switch(_outProdType) {
+ case LEFT_OUTER_PRODUCT: // LEFT: nrows of transposed X, ncols of U
+ newFedMap = newFedMap.transpose();
+ outDims[0] = X.getNumColumns();
+ outDims[1] = ec.getMatrixObject(_inputs[1]).getNumColumns();
+ break;
+ case RIGHT_OUTER_PRODUCT: // RIGHT: nrows of X, ncols of V
+ outDims[0] = X.getNumRows();
+ outDims[1] = ec.getMatrixObject(_inputs[2]).getNumColumns();
+ break;
+ case CELLWISE_OUTER_PRODUCT: // BASIC: preserve dimensions of X
+ outDims[0] = X.getNumRows();
+ outDims[1] = X.getNumColumns();
+ break;
+ default:
+ throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
}
- else if(outProdType == OutProdType.RIGHT_OUTER_PRODUCT) {
- if(fedType == FType.ROW) {
- // bind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- }
- else if(fedType == FType.COL) {
+
+ // derive output federated mapping
+ MatrixObject out = ec.getMatrixObject(_output);
+ int dim = (newFedMap.getType() == FType.ROW ? 1 : 0);
+ newFedMap = modifyFedRanges(newFedMap, dim, outDims[dim]);
+ out.setFedMapping(newFedMap);
+ }
+
+ private static FederationMap modifyFedRanges(FederationMap fedMap, int dim, long value) {
+ IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+ fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+ fedMap.getFederatedRanges()[i].setEndDim(dim, value);
+ });
+ return fedMap;
+ }
+
+ protected void aggResult(ExecutionContext ec, Future<FederatedResponse>[] response,
+ FederationMap fedMap) {
+ AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
+ switch(_outProdType) {
+ case LEFT_OUTER_PRODUCT:
+ case RIGHT_OUTER_PRODUCT:
// aggregate partial results from federated responses as elementwise sum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setMatrixOutput(_output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
- }
- else {
- throw new DMLRuntimeException("Only row partitioned or column" +
- " partitioned federated matrices supported yet.");
- }
- }
- else if(outProdType == OutProdType.CELLWISE_OUTER_PRODUCT) {
- if(fedType == FType.ROW) {
- // rbind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, false));
- }
- else if(fedType == FType.COL) {
- // cbind partial results from federated responses
- ec.setMatrixOutput(_output.getName(), FederationUtils.bind(response, true));
- }
- else {
- throw new DMLRuntimeException("Only row partitioned or column" +
- " partitioned federated matrices supported yet.");
- }
- }
- else if(outProdType == OutProdType.AGG_OUTER_PRODUCT) {
- // aggregate partial results from federated responses as sum
- AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
- ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
- }
- else {
- throw new DMLRuntimeException("Outer Product Type " + outProdType + " not supported yet.");
+ break;
+ case AGG_OUTER_PRODUCT:
+ // aggregate partial results from federated responses as sum
+ ec.setVariable(_output.getName(), FederationUtils.aggScalar(aop, response));
+ break;
+ default:
+ throw new DMLRuntimeException("Outer Product Type " + _outProdType + " not supported yet.");
}
}
}
@@ -458,5 +528,4 @@ public class SpoofFEDInstruction extends FEDInstruction
}
return retVal;
}
-
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
index 65f1728..61722db 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCodegenMultipleFedMOTest.java
@@ -104,7 +104,7 @@ public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
// row partitioned
// {201, 6, 4, 6, 4, true},
{202, 6, 4, 6, 4, true},
- // {203, 20, 1, 20, 1, true},
+ // FIXME: [SYSTEMDS-3110] {203, 20, 1, 20, 1, true},
// col partitioned
{201, 6, 4, 6, 4, false},
{202, 6, 4, 6, 4, false},
@@ -123,9 +123,9 @@ public class FederatedCodegenMultipleFedMOTest extends AutomatedTestBase
{308, 1000, 2000, 10, 2000, false},
// {310, 1000, 2000, 10, 2000, false},
// row and col partitioned
- // {311, 1000, 2000, 1000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+ // {311, 1000, 2000, 1000, 10, true}, // FIXME: ArrayIndexOutOfBoundsException in dotProduct
{312, 1000, 2000, 10, 2000, false},
- // {313, 4000, 2000, 4000, 10, true}, // not working yet - ArrayIndexOutOfBoundsException in dotProduct
+ // {313, 4000, 2000, 4000, 10, true}, // FIXME: ArrayIndexOutOfBoundsException in dotProduct
{314, 4000, 2000, 10, 2000, false},
// combined tests
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
index edc9ab7..cef5fd5 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedOuterProductTmplTest.java
@@ -86,14 +86,14 @@ public class FederatedOuterProductTmplTest extends AutomatedTestBase
{9, 1000, 2000, true},
// column partitioned
- //FIXME {1, 2000, 2000, false},
+ {1, 2000, 2000, false},
// {2, 4000, 2000, false},
// {3, 1000, 1000, false},
- //FIXME {4, 4000, 2000, false},
- //FIXME {5, 4000, 2000, false},
+ {4, 4000, 2000, false},
+ {5, 4000, 2000, false},
// {6, 4000, 2000, false},
//FIXME {7, 2000, 2000, false},
- //FIXME {8, 1000, 2000, false},
+ {8, 1000, 2000, false},
// {9, 1000, 2000, false},
});
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index b4bff76..89475d8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -117,7 +117,7 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
}
@Test
- public void federatedCodegenCellwiseHybrid() {
+ public void federatedCodegenRowwiseHybrid() {
testFederatedCodegenRowwise(ExecMode.HYBRID);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
index 71160b7..f2e873c 100644
--- a/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/pipelines/BuiltinTopkEvaluateTest.java
@@ -25,7 +25,6 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Ignore;
-import org.junit.Test;
public class BuiltinTopkEvaluateTest extends AutomatedTestBase {
// private final static String TEST_NAME1 = "prioritized";