You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2022/01/16 13:06:47 UTC
[systemds] 02/03: [SYSTEMDS-3243] Compressed Matrix Multiplication part
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a98560cb306012771dce215bda150a89dd9bf482
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sun Jan 16 14:02:08 2022 +0100
[SYSTEMDS-3243] Compressed Matrix Multiplication part
This commit follows the previous one by modifying the compression tests
and the compression path for matrix multiplication to fit the design of
the normal MatrixBlock.
Closes #1480
---
.../runtime/compress/CompressedMatrixBlock.java | 105 +----------------
.../runtime/compress/lib/CLALibMatrixMult.java | 128 +++++++++++++++++++++
.../component/compress/CompressedTestBase.java | 3 +-
.../test/component/estim/OpBindChainTest.java | 4 +-
.../test/component/estim/OpElemWChainTest.java | 4 +-
.../component/estim/SquaredProductChainTest.java | 2 +-
6 files changed, 140 insertions(+), 106 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index 85cc23b..ec09226 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -51,8 +51,8 @@ import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibMMChain;
+import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult;
import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
-import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibScalar;
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
import org.apache.sysds.runtime.compress.lib.CLALibSquash;
@@ -61,13 +61,11 @@ import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
-import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -76,7 +74,6 @@ import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.IJV;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
-import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.LibMatrixTercell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
@@ -471,105 +468,15 @@ public class CompressedMatrixBlock extends MatrixBlock {
}
+ /**
+  * Matrix multiplication m1 %*% m2 without transposition of either operand.
+  *
+  * The inputs are validated with checkAggregateBinaryOperations and the work is
+  * delegated to CLALibMatrixMult.matrixMultiply using the thread count of op.
+  * NOTE(review): unlike the removed code path, this no longer checks that this
+  * block is m1 or m2; presumably at least one operand must still be compressed
+  * for the multiplication to succeed — confirm callers guarantee this.
+  *
+  * @param m1  The left operand
+  * @param m2  The right operand
+  * @param ret The output block passed on to the multiplication
+  * @param op  The operator, providing the parallelization degree
+  * @return The result of m1 %*% m2
+  */
@Override
- public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
- AggregateBinaryOperator op) {
- // create output matrix block
- return aggregateBinaryOperations(m1, m2, ret, op, false, false);
+ public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+ checkAggregateBinaryOperations(m1, m2, op);
+ return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), false, false);
}
+ /**
+  * Matrix multiplication op(m1) %*% op(m2), where op transposes the respective
+  * operand if transposeLeft / transposeRight is set.
+  *
+  * The inputs are validated with checkAggregateBinaryOperations and the work is
+  * delegated to CLALibMatrixMult.matrixMultiply using the thread count of op.
+  *
+  * @param m1             The left operand
+  * @param m2             The right operand
+  * @param ret            The output block passed on to the multiplication
+  * @param op             The operator, providing the parallelization degree
+  * @param transposeLeft  If m1 should be transposed before multiplying
+  * @param transposeRight If m2 should be transposed before multiplying
+  * @return The result block, possibly a different object than ret
+  */
public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) {
- validateMatrixMult(m1, m2);
- final int k = op.getNumThreads();
- final Timing time = LOG.isTraceEnabled() ? new Timing(true) : null;
-
- if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock) {
- return doubleCompressedAggregateBinaryOperations((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret,
- op, transposeLeft, transposeRight);
- }
- boolean transposeOutput = false;
- if(transposeLeft || transposeRight) {
-
- if((m1 instanceof CompressedMatrixBlock && transposeLeft) ||
- (m2 instanceof CompressedMatrixBlock && transposeRight)) {
- // change operation from m1 %*% m2 -> t( t(m2) %*% t(m1) )
- transposeOutput = true;
- MatrixBlock tmp = m1;
- m1 = m2;
- m2 = tmp;
- boolean tmpLeft = transposeLeft;
- transposeLeft = !transposeRight;
- transposeRight = !tmpLeft;
-
- }
-
- if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) {
- m1 = LibMatrixReorg.transpose(m1, k);
- transposeLeft = false;
- }
- else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) {
- m2 = LibMatrixReorg.transpose(m2, k);
- transposeRight = false;
- }
- }
-
- final boolean right = (m1 == this);
- final MatrixBlock that = right ? m2 : m1;
-
- // create output matrix block
- if(right)
- ret = CLALibRightMultBy.rightMultByMatrix(this, that, ret, op.getNumThreads());
- else
- ret = CLALibLeftMultBy.leftMultByMatrix(this, that, ret, op.getNumThreads());
-
- if(LOG.isTraceEnabled())
- LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
- + m2.getNumColumns() + " in " + time.stop() + "ms.");
-
- if(transposeOutput) {
- if(ret instanceof CompressedMatrixBlock) {
- LOG.warn("Transposing decompression");
- ret = ((CompressedMatrixBlock) ret).decompress(k);
- }
- ret = LibMatrixReorg.transpose(ret, k);
- }
-
- return ret;
- }
-
- private void validateMatrixMult(MatrixBlock m1, MatrixBlock m2) {
- if(!(m1 == this || m2 == this))
- throw new DMLRuntimeException("Invalid aggregateBinaryOperation One of either input should be this");
- }
-
- private MatrixBlock doubleCompressedAggregateBinaryOperations(CompressedMatrixBlock m1, CompressedMatrixBlock m2,
- MatrixBlock ret, AggregateBinaryOperator op, boolean transposeLeft, boolean transposeRight) {
- if(!transposeLeft && !transposeRight) {
- // If both are not transposed, decompress the right hand side. to enable
- // compressed overlapping output.
- LOG.warn("Matrix decompression from multiplying two compressed matrices.");
- return aggregateBinaryOperations(m1, getUncompressed(m2), ret, op, transposeLeft, transposeRight);
- }
- else if(transposeLeft && !transposeRight) {
- if(m1.getNumColumns() > m2.getNumColumns()) {
- ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, op.getNumThreads());
- ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
- return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
- }
- else
- return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, op.getNumThreads());
-
- }
- else if(!transposeLeft && transposeRight) {
- throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix");
- // worst situation since it blows up the result matrix in number of rows in
- // either compressed matrix.
- }
- else {
- ret = aggregateBinaryOperations(m2, m1, ret, op);
- ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
- return ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
- }
+ checkAggregateBinaryOperations(m1, m2, op, transposeLeft, transposeRight);
+ return CLALibMatrixMult.matrixMultiply(m1, m2, ret, op.getNumThreads(), transposeLeft, transposeRight);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
new file mode 100644
index 0000000..941338d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+
+/**
+ * Matrix multiplication where at least one operand is a {@link CompressedMatrixBlock}.
+ * <p>
+ * Transposed operands are supported. Transposing a compressed operand is avoided by rewriting
+ * {@code op(A) %*% op(B)} into {@code t( t(op(B)) %*% t(op(A)) )} and transposing the (decompressed) result.
+ */
+public class CLALibMatrixMult {
+	private static final Log LOG = LogFactory.getLog(CLALibMatrixMult.class.getName());
+
+	private CLALibMatrixMult() {
+		// static utility class, not meant to be instantiated
+	}
+
+	/**
+	 * Multiply m1 %*% m2 without transposing either operand.
+	 *
+	 * @param m1  The left operand
+	 * @param m2  The right operand
+	 * @param ret The output block handed to the multiplication kernels; the returned block may be a different object
+	 * @param k   The parallelization degree
+	 * @return The result of m1 %*% m2
+	 */
+	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
+		return matrixMultiply(m1, m2, ret, k, false, false);
+	}
+
+	/**
+	 * Multiply op(m1) %*% op(m2), where op transposes its argument if the corresponding flag is set.
+	 *
+	 * @param m1             The left operand
+	 * @param m2             The right operand
+	 * @param ret            The output block handed to the multiplication kernels; the returned block may be a
+	 *                       different object
+	 * @param k              The parallelization degree
+	 * @param transposeLeft  If the left operand should be transposed
+	 * @param transposeRight If the right operand should be transposed
+	 * @return The result of op(m1) %*% op(m2)
+	 * @throws DMLCompressionException If neither operand is compressed, or if both are compressed and only the right
+	 *                                 one is transposed (not implemented)
+	 */
+	public static MatrixBlock matrixMultiply(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k,
+		boolean transposeLeft, boolean transposeRight) {
+		// Without this check the cast below fails with an uninformative ClassCastException.
+		if(!(m1 instanceof CompressedMatrixBlock) && !(m2 instanceof CompressedMatrixBlock))
+			throw new DMLCompressionException(
+				"Invalid compressed matrix multiplication: neither the left nor the right operand is compressed");
+
+		final Timing time = LOG.isTraceEnabled() ? new Timing(true) : null;
+
+		if(m1 instanceof CompressedMatrixBlock && m2 instanceof CompressedMatrixBlock)
+			return doubleCompressedMatrixMultiply((CompressedMatrixBlock) m1, (CompressedMatrixBlock) m2, ret, k,
+				transposeLeft, transposeRight);
+
+		boolean transposeOutput = false;
+		if(transposeLeft || transposeRight) {
+			if((m1 instanceof CompressedMatrixBlock && transposeLeft) ||
+				(m2 instanceof CompressedMatrixBlock && transposeRight)) {
+				// Never transpose the compressed side: op(m1) %*% op(m2) = t( t(op(m2)) %*% t(op(m1)) ).
+				// Swap the operands, flip both flags and transpose the result at the end.
+				transposeOutput = true;
+				final MatrixBlock tmp = m1;
+				m1 = m2;
+				m2 = tmp;
+				final boolean tmpLeft = transposeLeft;
+				transposeLeft = !transposeRight;
+				transposeRight = !tmpLeft;
+			}
+
+			// Only the uncompressed operand can still be flagged for transposition at this point.
+			if(!(m1 instanceof CompressedMatrixBlock) && transposeLeft) {
+				m1 = LibMatrixReorg.transpose(m1, k);
+				transposeLeft = false;
+			}
+			else if(!(m2 instanceof CompressedMatrixBlock) && transposeRight) {
+				m2 = LibMatrixReorg.transpose(m2, k);
+				transposeRight = false;
+			}
+		}
+
+		// right: the compressed block is the left operand and is multiplied by the uncompressed one from the right.
+		final boolean right = m1 instanceof CompressedMatrixBlock;
+		final CompressedMatrixBlock c = (CompressedMatrixBlock) (right ? m1 : m2);
+		final MatrixBlock that = right ? m2 : m1;
+
+		if(right)
+			ret = CLALibRightMultBy.rightMultByMatrix(c, that, ret, k);
+		else
+			ret = CLALibLeftMultBy.leftMultByMatrix(c, that, ret, k);
+
+		// Test the timer instead of the log level: the level may change between the two checks.
+		if(time != null)
+			LOG.trace("MM: Time block w/ sharedDim: " + m1.getNumColumns() + " rowLeft: " + m1.getNumRows() + " colRight:"
+				+ m2.getNumColumns() + " in " + time.stop() + "ms.");
+
+		if(transposeOutput) {
+			if(ret instanceof CompressedMatrixBlock) {
+				LOG.warn("Transposing decompression");
+				ret = ((CompressedMatrixBlock) ret).decompress(k);
+			}
+			ret = LibMatrixReorg.transpose(ret, k);
+		}
+
+		return ret;
+	}
+
+	/**
+	 * Multiply op(m1) %*% op(m2) where both operands are compressed.
+	 *
+	 * @param m1             The compressed left operand
+	 * @param m2             The compressed right operand
+	 * @param ret            The output block handed to the multiplication kernels
+	 * @param k              The parallelization degree
+	 * @param transposeLeft  If the left operand should be transposed
+	 * @param transposeRight If the right operand should be transposed
+	 * @return The result of op(m1) %*% op(m2)
+	 * @throws DMLCompressionException If only the right operand is transposed (not implemented)
+	 */
+	private static MatrixBlock doubleCompressedMatrixMultiply(CompressedMatrixBlock m1, CompressedMatrixBlock m2,
+		MatrixBlock ret, int k, boolean transposeLeft, boolean transposeRight) {
+		if(!transposeLeft && !transposeRight) {
+			// Neither side is transposed: decompress the right hand side to enable a compressed (overlapping) output.
+			LOG.warn("Matrix decompression from multiplying two compressed matrices.");
+			return matrixMultiply(m1, CompressedMatrixBlock.getUncompressed(m2), ret, k, false, false);
+		}
+		else if(transposeLeft && !transposeRight) {
+			if(m1.getNumColumns() > m2.getNumColumns()) {
+				ret = CLALibLeftMultBy.leftMultByMatrixTransposed(m1, m2, ret, k);
+				final ReorgOperator rOp = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+				return ret.reorgOperations(rOp, new MatrixBlock(), 0, 0, 0);
+			}
+			else
+				return CLALibLeftMultBy.leftMultByMatrixTransposed(m2, m1, ret, k);
+		}
+		else if(!transposeLeft && transposeRight) {
+			// Worst case: m1 %*% t(m2) is rows(m1) x rows(m2), i.e. blown up in both dimensions.
+			throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix");
+		}
+		else {
+			// t(m1) %*% t(m2) = t(m2 %*% m1)
+			ret = matrixMult(m2, m1, ret, k);
+			final ReorgOperator rOp = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+			return ret.reorgOperations(rOp, new MatrixBlock(), 0, 0, 0);
+		}
+	}
+
+}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
index d310e27..5d201c0 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedTestBase.java
@@ -621,7 +621,6 @@ public abstract class CompressedTestBase extends TestBase {
return; // Early termination since the test does not test what we wanted.
// Make Operator
- AggregateBinaryOperator abop = InstructionUtils.getMatMultOperator(_k);
AggregateBinaryOperator abopSingle = InstructionUtils.getMatMultOperator(1);
// vector-matrix uncompressed
@@ -633,7 +632,7 @@ public abstract class CompressedTestBase extends TestBase {
ucRet = right.aggregateBinaryOperations(left, right, ucRet, abopSingle);
MatrixBlock ret2 = ((CompressedMatrixBlock) cmb).aggregateBinaryOperations(compMatrix, cmb, new MatrixBlock(),
- abop, transposeLeft, transposeRight);
+ abopSingle, transposeLeft, transposeRight);
compareResultMatrices(ucRet, ret2, 100);
}
diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
index f75ba17..12c66cf 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpBindChainTest.java
@@ -136,7 +136,7 @@ public class OpBindChainTest extends AutomatedTestBase
m2 = MatrixBlock.randOperations(n, k, sp[1], 1, 1, "uniform", 7);
m1.append(m2, m3, false);
m4 = MatrixBlock.randOperations(k, m, sp[1], 1, 1, "uniform", 5);
- m5 = m1.aggregateBinaryOperations(m3, m4,
+ m5 = m3.aggregateBinaryOperations(m3, m4,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
//System.out.println(est);
@@ -147,7 +147,7 @@ public class OpBindChainTest extends AutomatedTestBase
m2 = MatrixBlock.randOperations(m, n, sp[1], 1, 1, "uniform", 7);
m1.append(m2, m3, true);
m4 = MatrixBlock.randOperations(k+n, m, sp[1], 1, 1, "uniform", 5);
- m5 = m1.aggregateBinaryOperations(m3, m4,
+ m5 = m3.aggregateBinaryOperations(m3, m4,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m4), OpCode.MM)).getSparsity();
//System.out.println(est);
diff --git a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
index f6410e8..7a76d7a 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/OpElemWChainTest.java
@@ -129,7 +129,7 @@ public class OpElemWChainTest extends AutomatedTestBase
case MULT:
bOp = new BinaryOperator(Multiply.getMultiplyFnObject());
m1.binaryOperations(bOp, m2, m4);
- m5 = m1.aggregateBinaryOperations(m4, m3,
+ m5 = m4.aggregateBinaryOperations(m4, m3,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
// System.out.println(m5.getSparsity());
@@ -138,7 +138,7 @@ public class OpElemWChainTest extends AutomatedTestBase
case PLUS:
bOp = new BinaryOperator(Plus.getPlusFnObject());
m1.binaryOperations(bOp, m2, m4);
- m5 = m1.aggregateBinaryOperations(m4, m3,
+ m5 = m4.aggregateBinaryOperations(m4, m3,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
est = estim.estim(new MMNode(new MMNode(new MMNode(m1), new MMNode(m2), op), new MMNode(m3), OpCode.MM)).getSparsity();
// System.out.println(m5.getSparsity());
diff --git a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
index e35bc57..25cd99e 100644
--- a/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
+++ b/src/test/java/org/apache/sysds/test/component/estim/SquaredProductChainTest.java
@@ -132,7 +132,7 @@ public class SquaredProductChainTest extends AutomatedTestBase
MatrixBlock m3 = MatrixBlock.randOperations(n, n2, sp[2], 1, 1, "uniform", 3);
MatrixBlock m4 = m1.aggregateBinaryOperations(m1, m2,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
- MatrixBlock m5 = m1.aggregateBinaryOperations(m4, m3,
+ MatrixBlock m5 = m4.aggregateBinaryOperations(m4, m3,
new MatrixBlock(), InstructionUtils.getMatMultOperator(1));
//compare estimated and real sparsity