You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/12/13 18:10:41 UTC

[systemds] branch main updated: [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 71116ac  [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
71116ac is described below

commit 71116aca666633859284741878ef16248e1f32dd
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Dec 13 17:58:26 2021 +0100

    [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
    
    This PR cleans up AggregateBinaryCPInstruction to isolate the
    compressed and transposed instruction paths.
    A remaining TODO is to add the rewrite t(x) %*% y -> t(t(y) %*% x)
    inside the transposed path, to optimize the multiplication when one
    side is cheap to transpose.
    
    Closes #1482
---
 .../cp/AggregateBinaryCPInstruction.java           | 87 ++++++++++++++++------
 1 file changed, 64 insertions(+), 23 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
index 6981877..934ca7f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
@@ -22,22 +22,23 @@ package org.apache.sysds.runtime.instructions.cp;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 
 public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
 	// private static final Log LOG = LogFactory.getLog(AggregateBinaryCPInstruction.class.getName());
 
-	public boolean transposeLeft;
-	public boolean transposeRight;
+	final public boolean transposeLeft;
+	final public boolean transposeRight;
 
 	private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
 		String istr) {
 		super(CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
+		transposeLeft = false;
+		transposeRight = false;
 	}
 
 	private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
@@ -61,17 +62,30 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
 		CPOperand out = new CPOperand(parts[3]);
 		int k = Integer.parseInt(parts[4]);
 		AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
-		if ( numFields == 6 ){
+		if(numFields == 6) {
 			boolean isLeftTransposed = Boolean.parseBoolean(parts[5]);
 			boolean isRightTransposed = Boolean.parseBoolean(parts[6]);
 			return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed,
 				isRightTransposed);
 		}
-		else return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
+		else
+			return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
 	}
 
 	@Override
 	public void processInstruction(ExecutionContext ec) {
+		// check compressed inputs
+		final boolean comp1 = ec.getMatrixObject(input1.getName()).isCompressed();
+		final boolean comp2 = ec.getMatrixObject(input2.getName()).isCompressed();
+		if(comp1 || comp2)
+			processCompressedAggregateBinary(ec, comp1, comp2);
+		else if(transposeLeft || transposeRight)
+			processTransposedFusedAggregateBinary(ec);
+		else
+			precessNormal(ec);
+	}
+
+	private void precessNormal(ExecutionContext ec) {
 		// get inputs
 		MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
 		MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
@@ -80,26 +94,53 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
 		AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
 		MatrixBlock ret;
 
-		if(matBlock1 instanceof CompressedMatrixBlock) {
-			CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock1;
-			ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+		ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+
+		// release inputs/outputs
+		ec.releaseMatrixInput(input1.getName());
+		ec.releaseMatrixInput(input2.getName());
+		ec.setMatrixOutput(output.getName(), ret);
+	}
+
+	private void processTransposedFusedAggregateBinary(ExecutionContext ec) {
+		MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+		MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+		// compute matrix multiplication
+		AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+		MatrixBlock ret;
+
+		// TODO: Use rewrite rule here t(x) %*% y -> t(t(y) %*% x)
+		if(transposeLeft) {
+			matBlock1 = LibMatrixReorg.transpose(matBlock1, ab_op.getNumThreads());
+			ec.releaseMatrixInput(input1.getName());
 		}
-		else if(matBlock2 instanceof CompressedMatrixBlock) {
-			CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock2;
-			ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+		if(transposeRight) {
+			matBlock2 = LibMatrixReorg.transpose(matBlock2, ab_op.getNumThreads());
+			ec.releaseMatrixInput(input2.getName());
+		}
+
+		ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+		ec.releaseMatrixInput(input1.getName());
+		ec.releaseMatrixInput(input2.getName());
+		ec.setMatrixOutput(output.getName(), ret);
+	}
+
+	private void processCompressedAggregateBinary(ExecutionContext ec, boolean c1, boolean c2) {
+		MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+		MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+		// compute matrix multiplication
+		AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+		MatrixBlock ret;
+
+		if(c1) {
+			CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock1;
+			ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+				transposeRight);
 		}
 		else {
-			// todo move rewrite rule here. to do 
-			// t(x) %*% y -> t(t(y) %*% x)
-			if(transposeLeft){
-				ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
-				matBlock1 = matBlock1.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
-			}
-			if(transposeRight){
-				ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
-				matBlock2 = matBlock2.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
-			}
-			ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+			CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock2;
+			ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+				transposeRight);
 		}
 
 		// release inputs/outputs