You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/18 20:26:07 UTC

[systemds] branch master updated: [SYSTEMDS-3086] Fix federated wdivmm operations (federated output)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e57d6d  [SYSTEMDS-3086] Fix federated wdivmm operations (federated output)
2e57d6d is described below

commit 2e57d6d9fe066bf52c9499c09871357b7874416c
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Sep 18 22:24:44 2021 +0200

    [SYSTEMDS-3086] Fix federated wdivmm operations (federated output)
    
    Closes #1381.
---
 .../fed/QuaternaryWDivMMFEDInstruction.java        | 97 +++++++++++++++++-----
 1 file changed, 76 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
index a47e5d9..414a4ff 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.java
@@ -37,11 +37,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
 
 import java.util.ArrayList;
 import java.util.concurrent.Future;
+import java.util.stream.IntStream;
 
 public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 {
@@ -60,32 +60,35 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 	 * @param out             The Federated Result Z
 	 * @param opcode          ...
 	 * @param instruction_str ...
-	 */
-	protected QuaternaryWDivMMFEDInstruction(Operator operator,
+	*/
+
+	private QuaternaryOperator _qop;
+
+	protected QuaternaryWDivMMFEDInstruction(QuaternaryOperator operator,
 		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str)
 	{
 		super(FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instruction_str);
+		_qop = operator;
 	}
 
 	@Override
 	public void processInstruction(ExecutionContext ec)
 	{
-		QuaternaryOperator qop = (QuaternaryOperator) _optr;
-		final WDivMMType wdivmm_type = qop.wtype3;
+		final WDivMMType wdivmm_type = _qop.wtype3;
 		MatrixObject X = ec.getMatrixObject(input1);
 		MatrixObject U = ec.getMatrixObject(input2);
 		MatrixObject V = ec.getMatrixObject(input3);
 		ScalarObject eps = null;
 		MatrixObject MX = null;
 
-		if(qop.hasFourInputs()) {
+		if(_qop.hasFourInputs()) {
 			if(wdivmm_type == WDivMMType.MULT_MINUS_4_LEFT || wdivmm_type == WDivMMType.MULT_MINUS_4_RIGHT) {
 				MX = ec.getMatrixObject(_input4);
 			}
 			else {
 				eps = (_input4.getDataType() == DataType.SCALAR) ?
 					ec.getScalarInput(_input4) :
-					new DoubleObject(ec.getMatrixInput(_input4.getName()).quickGetValue(0, 0));
+					new DoubleObject(ec.getMatrixInput(_input4).quickGetValue(0, 0));
 			}
 		}
 
@@ -93,7 +96,7 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 			FederationMap fedMap = X.getFedMapping();
 			ArrayList<FederatedRequest[]> frSliced = new ArrayList<>();
 			ArrayList<FederatedRequest> frB = new ArrayList<>(); // FederatedRequests of broadcasts
-			long[] varNewIn = new long[qop.hasFourInputs() ? 4 : 3];
+			long[] varNewIn = new long[_qop.hasFourInputs() ? 4 : 3];
 			varNewIn[0] = fedMap.getID();
 
 			if(X.isFederated(FType.ROW)) { // row partitioned X
@@ -151,18 +154,23 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 			}
 
 			FederatedRequest frComp = FederationUtils.callInstruction(instString, output,
-				qop.hasFourInputs() ? new CPOperand[]{input1, input2, input3, _input4}
+				_qop.hasFourInputs() ? new CPOperand[]{input1, input2, input3, _input4}
 				: new CPOperand[]{input1, input2, input3}, varNewIn);
 
-			// get partial results from federated workers
-			FederatedRequest frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+			FederatedRequest frGet = null;
+
+			FederatedRequest frC = null;
+			if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
+				|| (wdivmm_type.isRight() && X.isFederated(FType.COL))) { // output needs local aggregation
+				// get partial results from federated workers
+				frGet = new FederatedRequest(RequestType.GET_VAR, frComp.getID());
+				// cleanup the federated request of the instruction call
+				frC = fedMap.cleanup(getTID(), frComp.getID());
+			}
 
-			ArrayList<FederatedRequest> frC = new ArrayList<>();
-			frC.add(fedMap.cleanup(getTID(), frComp.getID()));
-			
-			FederatedRequest[] frAll = ArrayUtils.addAll(ArrayUtils.addAll(
-				frB.toArray(new FederatedRequest[0]), frComp, frGet),
-				frC.toArray(new FederatedRequest[0]));
+			FederatedRequest[] frAll = (frGet == null ?
+					ArrayUtils.addAll(frB.toArray(new FederatedRequest[0]), frComp)
+					: ArrayUtils.addAll(frB.toArray(new FederatedRequest[0]), frComp, frGet, frC));
 
 			// execute federated instructions
 			Future<FederatedResponse>[] response = frSliced.isEmpty() ?
@@ -170,14 +178,13 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 					getTID(), true, frSliced.toArray(new FederatedRequest[0][]), frAll);
 
 			if((wdivmm_type.isLeft() && X.isFederated(FType.ROW))
-				|| (wdivmm_type.isRight() && X.isFederated(FType.COL))) {
+				|| (wdivmm_type.isRight() && X.isFederated(FType.COL))) { // local aggregation
 				// aggregate partial results from federated responses
 				AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
 				ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
 			}
 			else if(wdivmm_type.isLeft() || wdivmm_type.isRight() || wdivmm_type.isBasic()) {
-				// bind partial results from federated responses
-				ec.setMatrixOutput(output.getName(), FederationUtils.bind(response, false));
+				setFederatedOutput(X, U, V, ec, frComp.getID());
 			}
 			else {
 				throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants.");
@@ -188,5 +195,53 @@ public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction
 				+ X.isFederated() + ", " + U.isFederated() + ", " + V.isFederated() + ")");
 		}
 	}
-}
 
+	/**
+	 * Set the federated output according to the output data characteristics of
+	 * the different wdivmm types.
+	 */
+	private void setFederatedOutput(MatrixObject X, MatrixObject U, MatrixObject V, ExecutionContext ec, long fedMapID) {
+		final WDivMMType wdivmm_type = _qop.wtype3;
+		MatrixObject out = ec.getMatrixObject(output);
+		FederationMap outFedMap = X.getFedMapping().copyWithNewID(fedMapID);
+
+		long rows = -1;
+		long cols = -1;
+		if(wdivmm_type.isBasic()) {
+			// BASIC: preserve dimensions of X
+			rows = X.getNumRows();
+			cols = X.getNumColumns();
+		}
+		else if(wdivmm_type.isLeft()) {
+			// LEFT: nrows of transposed X, ncols of U
+			rows = X.getNumColumns();
+			cols = U.getNumColumns();
+			outFedMap = modifyFedRanges(outFedMap.transpose(), cols, 1);
+		}
+		else if(wdivmm_type.isRight()) {
+			// RIGHT: nrows of X, ncols of V
+			rows = X.getNumRows();
+			cols = V.getNumColumns();
+			outFedMap = modifyFedRanges(outFedMap, cols, 1);
+		}
+		out.setFedMapping(outFedMap);
+		out.getDataCharacteristics().set(rows, cols, (int) X.getBlocksize());
+	}
+
+	/**
+	 * Takes the federated mapping and sets one dimension of all federated ranges
+	 * to the specified value.
+	 *
+	 * @param fedMap     the original federated mapping
+	 * @param value      long value for setting the dimension
+	 * @param dim        indicates if the row (0) or column (1) dimension should be set to value
+	 * @return FederationMap with the modified federated ranges
+	 */
+	private static FederationMap modifyFedRanges(FederationMap fedMap, long value, int dim) {
+		IntStream.range(0, fedMap.getFederatedRanges().length).forEach(i -> {
+			fedMap.getFederatedRanges()[i].setBeginDim(dim, 0);
+			fedMap.getFederatedRanges()[i].setEndDim(dim, value);
+		});
+		return fedMap;
+	}
+}