You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/18 20:26:07 UTC
[systemds] branch master updated: [SYSTEMDS-3086] Fix federated wdivmm operations (federated output)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 2e57d6d [SYSTEMDS-3086] Fix federated wdivmm operations (federated output)
2e57d6d is described below
commit 2e57d6d9fe066bf52c9499c09871357b7874416c
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Sep 18 22:24:44 2021 +0200
[SYSTEMDS-3086] Fix federated wdivmm operations (federated output)
Closes #1381.
---
.../fed/QuaternaryWDivMMFEDInstruction.java | 97 +++++++++++++++++-----
1 file changed, 76 insertions(+), 21 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index a47e5d9..414a4ff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -37,11 +37,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import java.util.ArrayList;
import java.util.concurrent.Future;
+import java.util.stream.IntStream;
public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
{
@@ -60,32 +60,35 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
* @param out The Federated Result Z
* @param opcode ...
* @param instruction_str ...
- */
- protected QuaternaryWDivMMFEDInstruction(Operator operator,
+ */
+
+ private QuaternaryOperator _qop;
+
+ protected QuaternaryWDivMMFEDInstruction(QuaternaryOperator operator,
CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
{
super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+ _qop = operator;
}
@Override
public void processInstruction(ExecutionContext ec)
{
- QuaternaryOperator qop = (QuaternaryOperator) _optr;
- final WDivMMType wdivmm_type = qop.wtype3;
+ final WDivMMType wdivmm_type = _qop.wtype3;
MatrixObject X = ec.getMatrixObject(input1);
MatrixObject U = ec.getMatrixObject(input2);
MatrixObject V = ec.getMatrixObject(input3);
ScalarObject eps = null;
MatrixObject MX = null;
- if(qop.hasFourInputs()) {
+ if(_qop.hasFourInputs()) {
if(wdivmm_type == WDivMMType.MULT_MINUS_4_LEFT || wdivmm_type == WDivMMType.MULT_MINUS_4_RIGHT) {
MX = ec.getMatrixObject(_input4);
}
else {
eps = (_input4.getDataType() == DataType.SCALAR) ?
ec.getScalarInput(_input4) :
- new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+ new DoubleObject(ec.getMatrixInput(_input4).quickGetValue(0, 0));
}
}
@@ -93,7 +96,7 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
FederationMap fedMap = X.getFedMapping();
ArrayList<FederatedRequest[]> frSliced = new ArrayList<>();
ArrayList<FederatedRequest> frB = new ArrayList<>(); // FederatedRequests of broadcasts
- long[] varNewIn = new long[qop.hasFourInputs() ? 4 : 3];
+ long[] varNewIn = new long[_qop.hasFourInputs() ? 4 : 3];
varNewIn[0] = fedMap.getID();
if(X.isFederated(FType.ROW)) { // row partitioned X
@@ -151,18 +154,23 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
}
FederatedRequest frComp = FederationUtils.callInstruction(instString, output,
- qop.hasFourInputs() ? new CPOperand[]{input1, input2, input3, _input4}
+ _qop.hasFourInputs() ? new CPOperand[]{input1, input2, input3, _input4}
: new CPOperand[]{input1, input2, input3}, varNewIn);
- // get partial results from federated workers
- FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+ FederatedRequest frGet = null;
+
+ FederatedRequest frC = null;
+ if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
+ || (wdivmm_type.isRight() && X.isFederated(FType.COL))) { // output needs local aggregation
+ // get partial results from federated workers
+ frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+ // cleanup the federated request of the instruction call
+ frC = fedMap.cleanup(getTID(), frComp.getID());
+ }
- ArrayList<FederatedRequest> frC = new ArrayList<>();
- frC.add(fedMap.cleanup(getTID(), frComp.getID()));
-
- FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
- frB.toArray(new FederatedRequest[0]), frComp, frGet),
- frC.toArray(new FederatedRequest[0]));
+ FederatedRequest[] frAll = (frGet == null ?
+ ArrayUtils.addAll(frB.toArray(new FederatedRequest[0]), frComp)
+ : ArrayUtils.addAll(frB.toArray(new FederatedRequest[0]), frComp, frGet, frC));
// execute federated instructions
Future<FederatedResponse>[] response = frSliced.isEmpty() ?
@@ -170,14 +178,13 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
getTID(), true, frSliced.toArray(new FederatedRequest[0][]), frAll);
if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
- || (wdivmm_type.isRight() && X.isFederated(FType.COL))) {
+ || (wdivmm_type.isRight() && X.isFederated(FType.COL))) { // local aggregation
// aggregate partial results from federated responses
AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
}
else if(wdivmm_type.isLeft() || wdivmm_type.isRight() || wdivmm_type.isBasic()) {
- // bind partial results from federated responses
- ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false));
+ setFederatedOutput(X, U, V, ec, frComp.getID());
}
else {
throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants.");
@@ -188,5 +195,53 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
+ X.isFederated() + ", " + U.isFederated() + ", " + V.isFederated() + ")");
}
}
-}
+ /**
+ * Set the federated output according to the output data characteristics of
+ * the different wdivmm types; fedMapID is the ID of the workers' result variable.
+ */
+ private void setFederatedOutput(MatrixObject X, MatrixObject U, MatrixObject V, ExecutionContext ec, long fedMapID) {
+ final WDivMMType wdivmm_type = _qop.wtype3;
+ MatrixObject out = ec.getMatrixObject(output);
+ FederationMap outFedMap = X.getFedMapping().copyWithNewID(fedMapID);
+
+ long rows = -1;
+ long cols = -1;
+ if(wdivmm_type.isBasic()) {
+ // BASIC: preserve dimensions (and the federated ranges) of X
+ rows = X.getNumRows();
+ cols = X.getNumColumns();
+ }
+ else if(wdivmm_type.isLeft()) {
+ // LEFT: nrows of transposed X, ncols of U (row-partitioned X is aggregated locally instead)
+ rows = X.getNumColumns();
+ cols = U.getNumColumns();
+ outFedMap = modifyFedRanges(outFedMap.transpose(), cols, 1);
+ }
+ else if(wdivmm_type.isRight()) {
+ // RIGHT: nrows of X, ncols of V (col-partitioned X is aggregated locally instead)
+ rows = X.getNumRows();
+ cols = V.getNumColumns();
+ outFedMap = modifyFedRanges(outFedMap, cols, 1);
+ }
+ out.setFedMapping(outFedMap);
+ out.getDataCharacteristics().set(rows, cols, (int) X.getBlocksize());
+ }
+
+ /**
+ * Takes the federated mapping and sets one dimension of all federated ranges
+ * to begin 0 and end {@code value}; no copy is made, fedMap itself is returned.
+ *
+ * @param fedMap the federated mapping whose ranges are adjusted
+ * @param value the end value for the dimension (the begin is always set to 0)
+ * @param dim indicates if the row (0) or column (1) dimension should be set
+ * @return the same FederationMap instance that was passed in
+ */
+ private static FederationMap modifyFedRanges(FederationMap fedMap, long value, int dim) {
+ IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> { // NOTE(review): getFederatedRanges() is re-invoked per access; relies on it returning the live range objects — confirm
+ fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+ fedMap.getFederatedRanges()[i].setEndDim(dim, value);
+ });
+ return fedMap;
+ }
+}