You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this conversion.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/12/13 18:10:41 UTC
[systemds] branch main updated: [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 71116ac [SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
71116ac is described below
commit 71116aca666633859284741878ef16248e1f32dd
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Dec 13 17:58:26 2021 +0100
[SYSTEMDS-3248] Clean AggregateBinaryCPInstruction
This PR cleans up AggregateBinaryCPInstruction by isolating the compressed
and transposed instruction paths into separate methods.
A remaining TODO is to add the rewrite t(x) %*% y -> t(t(y) %*% x) to the
transposed path, to optimize the multiplication when one side is cheap to transpose.
Closes #1482
---
.../cp/AggregateBinaryCPInstruction.java | 87 ++++++++++++++++------
1 file changed, 64 insertions(+), 23 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
index 6981877..934ca7f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateBinaryCPInstruction.java
@@ -22,22 +22,23 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
-import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
// private static final Log LOG = LogFactory.getLog(AggregateBinaryCPInstruction.class.getName());
- public boolean transposeLeft;
- public boolean transposeRight;
+ final public boolean transposeLeft;
+ final public boolean transposeRight;
private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
String istr) {
super(CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
+ transposeLeft = false;
+ transposeRight = false;
}
private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
@@ -61,17 +62,30 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
CPOperand out = new CPOperand(parts[3]);
int k = Integer.parseInt(parts[4]);
AggregateBinaryOperator aggbin = InstructionUtils.getMatMultOperator(k);
- if ( numFields == 6 ){
+ if(numFields == 6) {
boolean isLeftTransposed = Boolean.parseBoolean(parts[5]);
boolean isRightTransposed = Boolean.parseBoolean(parts[6]);
return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str, isLeftTransposed,
isRightTransposed);
}
- else return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
+ else
+ return new AggregateBinaryCPInstruction(aggbin, in1, in2, out, opcode, str);
}
@Override
public void processInstruction(ExecutionContext ec) {
+ // check compressed inputs
+ final boolean comp1 = ec.getMatrixObject(input1.getName()).isCompressed();
+ final boolean comp2 = ec.getMatrixObject(input2.getName()).isCompressed();
+ if(comp1 || comp2)
+ processCompressedAggregateBinary(ec, comp1, comp2);
+ else if(transposeLeft || transposeRight)
+ processTransposedFusedAggregateBinary(ec);
+ else
+ precessNormal(ec);
+ }
+
+ private void precessNormal(ExecutionContext ec) {
// get inputs
MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
@@ -80,26 +94,53 @@ public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
MatrixBlock ret;
- if(matBlock1 instanceof CompressedMatrixBlock) {
- CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock1;
- ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+ ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+
+ // release inputs/outputs
+ ec.releaseMatrixInput(input1.getName());
+ ec.releaseMatrixInput(input2.getName());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+
+ private void processTransposedFusedAggregateBinary(ExecutionContext ec) {
+ MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+ MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+ // compute matrix multiplication
+ AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+ MatrixBlock ret;
+
+ // TODO: Use rewrite rule here t(x) %*% y -> t(t(y) %*% x)
+ if(transposeLeft) {
+ matBlock1 = LibMatrixReorg.transpose(matBlock1, ab_op.getNumThreads());
+ ec.releaseMatrixInput(input1.getName());
}
- else if(matBlock2 instanceof CompressedMatrixBlock) {
- CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock2;
- ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft, transposeRight);
+ if(transposeRight) {
+ matBlock2 = LibMatrixReorg.transpose(matBlock2, ab_op.getNumThreads());
+ ec.releaseMatrixInput(input2.getName());
+ }
+
+ ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+ ec.releaseMatrixInput(input1.getName());
+ ec.releaseMatrixInput(input2.getName());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+
+ private void processCompressedAggregateBinary(ExecutionContext ec, boolean c1, boolean c2) {
+ MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+ MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+ // compute matrix multiplication
+ AggregateBinaryOperator ab_op = (AggregateBinaryOperator) _optr;
+ MatrixBlock ret;
+
+ if(c1) {
+ CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock1;
+ ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+ transposeRight);
}
else {
- // todo move rewrite rule here. to do
- // t(x) %*% y -> t(t(y) %*% x)
- if(transposeLeft){
- ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
- matBlock1 = matBlock1.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
- }
- if(transposeRight){
- ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), ab_op.getNumThreads());
- matBlock2 = matBlock2.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
- }
- ret = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op);
+ CompressedMatrixBlock main = (CompressedMatrixBlock) matBlock2;
+ ret = main.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), ab_op, transposeLeft,
+ transposeRight);
}
// release inputs/outputs