You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2022/01/16 13:06:47 UTC

[systemds] 02/03: [SYSTEMDS-3243] Compressed Matrix Multiplication part

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit a98560cb306012771dce215bda150a89dd9bf482
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sun Jan 16 14:02:08 2022 +0100

    [SYSTEMDS-3243] Compressed Matrix Multiplication part
    
    This commit follows the previous one by modifying the compression tests
    and the compression path for matrix multiplication to fit the design of
    the normal MatrixBlock.
    
    Closes #1480
---
 .../runtime/compress/CompressedMatrixBlock.java    | 105 +----------------
 .../runtime/compress/lib/CLALibMatrixMult.java     | 128 +++++++++++++++++++++
 .../component/compress/CompressedTestBase.java     |   3 +-
 .../test/component/estim/OpBindChainTest.java      |   4 +-
 .../test/component/estim/OpElemWChainTest.java     |   4 +-
 .../component/estim/SquaredProductChainTest.java   |   2 +-
 6 files changed, 140 insertions(+), 106 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 85cc23b..ec09226 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -51,8 +51,8 @@ import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
 import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
+import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
 import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
-import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibScalar;
 import org.apache.sysds.runtime.compress.lib.CLALibSlice;
 import org.apache.sysds.runtime.compress.lib.CLALibSquash;
@@ -61,13 +61,11 @@ import org.apache.sysds.runtime.compress.lib.CLALibUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseRow;
 import org.apache.sysds.runtime.functionobjects.MinusMultiply;
 import org.apache.sysds.runtime.functionobjects.PlusMultiply;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -76,7 +74,6 @@ import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.IJV;
 import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
-import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.LibMatrixTercell;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
@@ -471,105 +468,15 @@ public class CompressedMatrixBlock extends MatrixBlock {
 	}
 
 	@Override
-	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
-		AggregateBinaryOperator op) {
-		// create output matrix block
-		return aggregateBinaryOperations(m1, m2, ret, op, false, false);
+	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+		checkAggregateBinaryOperations(m1, m2, op);
+		return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), false, false);
 	}
 
 	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
 		AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) {
-		validateMatrixMult(m1, m2);
-		final int k = op.getNumThreads();
-		final Timing time = LOG.isTraceEnabled() ? new Timing(true) : null;
-
-		if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock) {
-			return doubleCompressedAggregateBinaryOperations((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret,
-				op, transposeLeft, transposeRight);
-		}
-		boolean transposeOutput = false;
-		if(transposeLeft || transposeRight) {
-
-			if((m1 instanceof CompressedMatrixBlock && transposeLeft) ||
-				(m2 instanceof CompressedMatrixBlock && transposeRight)) {
-				// change operation from m1 %*% m2 -> t( t(m2) %*% t(m1) )
-				transposeOutput = true;
-				MatrixBlock tmp = m1;
-				m1 = m2;
-				m2 = tmp;
-				boolean tmpLeft = transposeLeft;
-				transposeLeft = !transposeRight;
-				transposeRight = !tmpLeft;
-
-			}
-
-			if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) {
-				m1 = LibMatrixReorg.transpose(m1, k);
-				transposeLeft = false;
-			}
-			else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) {
-				m2 = LibMatrixReorg.transpose(m2, k);
-				transposeRight = false;
-			}
-		}
-
-		final boolean right = (m1 == this);
-		final MatrixBlock that = right ? m2 : m1;
-
-		// create output matrix block
-		if(right)
-			ret = CLALibRightMultBy.rightMultByMatrix(this, that, ret, op.getNumThreads());
-		else
-			ret = CLALibLeftMultBy.leftMultByMatrix(this, that, ret, op.getNumThreads());
-
-		if(LOG.isTraceEnabled())
-			LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
-				+ m2.getNumColumns() + " in " + time.stop() + "ms.");
-
-		if(transposeOutput) {
-			if(ret instanceof CompressedMatrixBlock) {
-				LOG.warn("Transposing decompression");
-				ret = ((CompressedMatrixBlock) ret).decompress(k);
-			}
-			ret = LibMatrixReorg.transpose(ret, k);
-		}
-
-		return ret;
-	}
-
-	private void validateMatrixMult(MatrixBlock m1, MatrixBlock m2) {
-		if(!(m1 == this || m2 == this))
-			throw new DMLRuntimeException("Invalid aggregateBinaryOperation One of either input should be this");
-	}
-
-	private MatrixBlock doubleCompressedAggregateBinaryOperations(CompressedMatrixBlock m1, CompressedMatrixBlock m2,
-		MatrixBlock ret, AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) {
-		if(!transposeLeft && !transposeRight) {
-			// If both are not transposed, decompress the right hand side. to enable
-			// compressed overlapping output.
-			LOG.warn("Matrix decompression from multiplying two compressed matrices.");
-			return aggregateBinaryOperations(m1, getUncompressed(m2), ret, op, transposeLeft, transposeRight);
-		}
-		else if(transposeLeft && !transposeRight) {
-			if(m1.getNumColumns() > m2.getNumColumns()) {
-				ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, op.getNumThreads());
-				ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
-				return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
-			}
-			else
-				return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, op.getNumThreads());
-
-		}
-		else if(!transposeLeft && transposeRight) {
-			throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix");
-			// worst situation since it blows up the result matrix in number of rows in
-			// either compressed matrix.
-		}
-		else {
-			ret = aggregateBinaryOperations(m2, m1, ret, op);
-			ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
-			return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
-		}
+		checkAggregateBinaryOperations(m1, m2, op, transposeLeft, transposeRight);
+		return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), transposeLeft, transposeRight);
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
new file mode 100644
index 0000000..941338d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+public class CLALibMatrixMult {
+	private static final Log LOG = LogFactory.getLog(CLALibMatrixMult.class.getName());
+
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
+		return matrixMultiply(m1, m2, ret, k, false, false);
+	}
+
+	public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
+		int k, boolean transposeLeft, boolean transposeRight) {
+		final Timing time = LOG.isTraceEnabled() ? new Timing(true) : null;
+
+		if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock) {
+			return doubleCompressedMatrixMultiply((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret,
+				k, transposeLeft, transposeRight);
+		}
+
+		boolean transposeOutput = false;
+		if(transposeLeft || transposeRight) {
+
+			if((m1 instanceof CompressedMatrixBlock && transposeLeft) ||
+				(m2 instanceof CompressedMatrixBlock && transposeRight)) {
+				// change operation from m1 %*% m2 -> t( t(m2) %*% t(m1) )
+				transposeOutput = true;
+				MatrixBlock tmp = m1;
+				m1 = m2;
+				m2 = tmp;
+				boolean tmpLeft = transposeLeft;
+				transposeLeft = !transposeRight;
+				transposeRight = !tmpLeft;
+			}
+
+			if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) {
+				m1 = LibMatrixReorg.transpose(m1, k);
+				transposeLeft = false;
+			}
+			else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) {
+				m2 = LibMatrixReorg.transpose(m2, k);
+				transposeRight = false;
+			}
+		}
+
+		final boolean right = (m1 instanceof CompressedMatrixBlock);
+		final CompressedMatrixBlock c =(CompressedMatrixBlock) (right ? m1 : m2);
+		final MatrixBlock that = right ? m2 : m1;
+
+		// create output matrix block
+		if(right)
+			ret = CLALibRightMultBy.rightMultByMatrix(c, that, ret, k);
+		else
+			ret = CLALibLeftMultBy.leftMultByMatrix(c, that, ret, k);
+
+		if(LOG.isTraceEnabled())
+			LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
+				+ m2.getNumColumns() + " in " + time.stop() + "ms.");
+
+		if(transposeOutput) {
+			if(ret instanceof CompressedMatrixBlock) {
+				LOG.warn("Transposing decompression");
+				ret = ((CompressedMatrixBlock) ret).decompress(k);
+			}
+			ret = LibMatrixReorg.transpose(ret, k);
+		}
+
+		return ret;
+	}
+
+	private static MatrixBlock doubleCompressedMatrixMultiply(CompressedMatrixBlock m1, CompressedMatrixBlock m2,
+		MatrixBlock ret, int k, boolean transposeLeft, boolean transposeRight) {
+		if(!transposeLeft && !transposeRight) {
+			// If both are not transposed, decompress the right hand side. to enable
+			// compressed overlapping output.
+			LOG.warn("Matrix decompression from multiplying two compressed matrices.");
+			return matrixMultiply(m1, CompressedMatrixBlock.getUncompressed(m2), ret, k, transposeLeft, transposeRight);
+		}
+		else if(transposeLeft && !transposeRight) {
+			if(m1.getNumColumns() > m2.getNumColumns()) {
+				ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, k);
+				ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+				return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
+			}
+			else
+				return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, k);
+
+		}
+		else if(!transposeLeft && transposeRight) {
+			throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix");
+			// worst situation since it blows up the result matrix in number of rows in
+			// either compressed matrix.
+		}
+		else {
+			ret = CLALibMatrixMult.matrixMult(m2, m1, ret, k);
+			ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+			return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
+		}
+	}
+
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index d310e27..5d201c0 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -621,7 +621,6 @@ public abstract class CompressedTestBase extends TestBase {
 				return; // Early termination since the test does not test what we wanted.
 
 			// Make Operator
-			AggregateBinaryOperator abop = InstructionUtils.getMatMultOperator(_k);
 			AggregateBinaryOperator abopSingle = InstructionUtils.getMatMultOperator(1);
 
 			// vector-matrix uncompressed
@@ -633,7 +632,7 @@ public abstract class CompressedTestBase extends TestBase {
 			ucRet = right.aggregateBinaryOperations(left, right, ucRet, abopSingle);
 
 			MatrixBlock ret2 = ((CompressedMatrixBlock) cmb).aggregateBinaryOperations(compMatrix, cmb, new MatrixBlock(),
-				abop, transposeLeft, transposeRight);
+			abopSingle, transposeLeft, transposeRight);
 
 			compareResultMatrices(ucRet, ret2, 100);
 		}
diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
index f75ba17..12c66cf 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
@@ -136,7 +136,7 @@ public class OpBindChainTest extends AutomatedTestBase
 				m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 1, "uniform", 7);
 				m1.append(m2, m3, false);
 				m4 = MatrixBlock.randOperations(k, m, sp[1], 1, 1, "uniform", 5);
-				m5 = m1.aggregateBinaryOperations(m3, m4, 
+				m5 = m3.aggregateBinaryOperations(m3, m4, 
 						new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
 				est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
 				//System.out.println(est);
@@ -147,7 +147,7 @@ public class OpBindChainTest extends AutomatedTestBase
 				m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7);
 				m1.append(m2, m3, true);
 				m4 = MatrixBlock.randOperations(k+n, m, sp[1], 1, 1, "uniform", 5);
-				m5 = m1.aggregateBinaryOperations(m3, m4, 
+				m5 = m3.aggregateBinaryOperations(m3, m4, 
 						new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
 				est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
 				//System.out.println(est);
diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
index f6410e8..7a76d7a 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
@@ -129,7 +129,7 @@ public class OpElemWChainTest extends AutomatedTestBase
 			case MULT:
 				bOp = new BinaryOperator(Multiply.getMultiplyFnObject());
 				m1.binaryOperations(bOp, m2, m4);
-				m5 = m1.aggregateBinaryOperations(m4, m3, 
+				m5 = m4.aggregateBinaryOperations(m4, m3, 
 						new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
 				est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
 				// System.out.println(m5.getSparsity());
@@ -138,7 +138,7 @@ public class OpElemWChainTest extends AutomatedTestBase
 			case PLUS:
 				bOp = new BinaryOperator(Plus.getPlusFnObject());
 				m1.binaryOperations(bOp, m2, m4);
-				m5 = m1.aggregateBinaryOperations(m4, m3, 
+				m5 = m4.aggregateBinaryOperations(m4, m3, 
 						new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
 				est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
 				// System.out.println(m5.getSparsity());
diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
index e35bc57..25cd99e 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
@@ -132,7 +132,7 @@ public class SquaredProductChainTest extends AutomatedTestBase
 		MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1, "uniform", 3);
 		MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2, 
 			new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
-		MatrixBlock m5 = m1.aggregateBinaryOperations(m4, m3, 
+		MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3, 
 			new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
 		
 		//compare estimated and real sparsity