You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/08/26 20:24:21 UTC
[systemds] branch master updated: [SYSTEMDS-3105] CLA Left MM shared common element sum
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 217e31d [SYSTEMDS-3105] CLA Left MM shared common element sum
217e31d is described below
commit 217e31d87b2ffc745c8a8e1d56cee80e7013aa35
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Thu Aug 26 14:20:12 2021 +0200
[SYSTEMDS-3105] CLA Left MM shared common element sum
This commit modifies the Left Matrix Multiplication (LMM) to share a
common element sum of all SDC column groups, allowing them to skip
all their common elements.
This improves the performance of LMM on InfiniMnist by 5-10x, depending
on the number of rows in the left-hand uncompressed matrix.
Closes #1376
---
.../runtime/compress/colgroup/ColGroupFactory.java | 39 +++--
.../runtime/compress/colgroup/ColGroupSDC.java | 4 +-
.../runtime/compress/colgroup/ColGroupValue.java | 1 +
.../runtime/compress/colgroup/offset/AOffset.java | 7 +-
.../runtime/compress/lib/CLALibLeftMultBy.java | 163 +++++++++++++++------
.../sysds/runtime/matrix/data/MatrixBlock.java | 14 +-
6 files changed, 163 insertions(+), 65 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 1d7b97d..f42e382 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
+import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory.CostType;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
@@ -214,34 +215,40 @@ public final class ColGroupFactory {
CompressedSizeInfoColGroup cg, int[] colIndexes) {
try {
final int nrUniqueEstimate = cg.getNumVals();
- final CompressionType estimatedBestCompressionType = cg.getBestCompressionType();
+ CompressionType estimatedBestCompressionType = cg.getBestCompressionType();
+
+ if(estimatedBestCompressionType == CompressionType.SDC && cs.costComputationType == CostType.W_TREE) {
+ if(cg.getCompressionSize(CompressionType.DDC) * 3 < cg.getCompressionSize(CompressionType.SDC))
+ estimatedBestCompressionType = CompressionType.DDC;
+ }
+
if(estimatedBestCompressionType == CompressionType.UNCOMPRESSED) {
// shortcut if uncompressed
return new ColGroupUncompressed(colIndexes, in, cs.transposed);
}
else if(estimatedBestCompressionType == CompressionType.SDC && colIndexes.length == 1 &&
in.isInSparseFormat() && cs.transposed) {
- // shortcut for creating SDC!
- // throw new NotImplementedException();
+
return compressSDCZero(in.getSparseBlock(), colIndexes, in.getNumColumns(),
tmp.getDblCountMap(nrUniqueEstimate));
}
else {
- ABitmap ubm;
- if(colIndexes.length > 1)
- ubm = BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
- tmp.getDblArrayMap(nrUniqueEstimate));
- else
- ubm = BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
-
- CompressedSizeEstimator estimator = new CompressedSizeEstimatorExact(in, cs);
+ final int numRows = cs.transposed ? in.getNumColumns() : in.getNumRows();
- CompressedSizeInfoColGroup sizeInfo = new CompressedSizeInfoColGroup(
- estimator.estimateCompressedColGroupSize(ubm, colIndexes), cs.validCompressions, ubm);
+ if(colIndexes.length > 1) {
+ final ABitmap ubm = BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
+ tmp.getDblArrayMap(nrUniqueEstimate));
+ CompressedSizeEstimator estimator = new CompressedSizeEstimatorExact(in, cs);
+ CompressedSizeInfoColGroup sizeInfo = new CompressedSizeInfoColGroup(
+ estimator.estimateCompressedColGroupSize(ubm, colIndexes), cs.validCompressions, ubm);
+ return compress(colIndexes, numRows, ubm, estimatedBestCompressionType, cs, in,
+ sizeInfo.getTupleSparsity());
+ }
+ else {
+ final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
+ return compress(colIndexes, numRows, ubm, estimatedBestCompressionType, cs, in, 1.0);
+ }
- int numRows = cs.transposed ? in.getNumColumns() : in.getNumRows();
- return compress(colIndexes, numRows, ubm, sizeInfo.getBestCompressionType(cs), cs, in,
- sizeInfo.getTupleSparsity());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 23adae9..ffe38b1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -280,7 +280,7 @@ public class ColGroupSDC extends ColGroupValue {
final double[] mV = m.getDenseBlockValues();
final double[] preAV = preAgg.getDenseBlockValues();
final int numVals = getNumValues();
- AIterator itStart = _indexes.getIterator(cl);
+ final AIterator itStart = _indexes.getIterator(cl);
AIterator it = null;
for(int rowLeft = rl, offOut = 0; rowLeft < ru; rowLeft++, offOut += numVals) {
final int offLeft = rowLeft * _numRows;
@@ -300,7 +300,7 @@ public class ColGroupSDC extends ColGroupValue {
preAV[def] += mV[offLeft + rc];
}
}
- if(it != null)
+ if(it != null && cu < m.getNumColumns())
_indexes.cacheIterator(it, cu + 1);
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 7307bc2..6844095 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -887,6 +887,7 @@ public abstract class ColGroupValue extends ColGroupCompressed implements Clonea
if(right.length != rightColumns.length)
throw new DMLCompressionException(
"Error right not equal length " + right.length + " " + rightColumns.length);
+
for(int row = 0; row < leftRows.length; row++) {
final int outputRowOffset = leftRows[row] * outCols;
final double vLeft = left[row];
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index 7f42240..ac359b3 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -59,8 +59,11 @@ public abstract class AOffset implements Serializable {
public AIterator getIterator(int row) {
if(skipIterators != null) {
Map<Integer, AIterator> sk = skipIterators.get();
- if(sk != null && sk.containsKey(row))
- return sk.get(row).clone();
+ if(sk != null && sk.containsKey(row)){
+ AIterator it = sk.get(row);
+ if(it != null)
+ return it.clone();
+ }
}
AIterator it = getIterator();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 04ccdc2..3ca657b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -33,6 +33,8 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
import org.apache.sysds.runtime.functionobjects.Plus;
@@ -92,7 +94,7 @@ public class CLALibLeftMultBy {
final boolean overlapping = cmb.isOverlapping();
List<AColGroup> groups = cmb.getColGroups();
result.allocateDenseBlock();
-
+
if(overlapping) {
LOG.warn("Inefficient TSMM with overlapping matrix could be implemented multi-threaded but is not yet.");
leftMultByCompressedTransposedMatrix(groups, groups, result);
@@ -110,7 +112,7 @@ public class CLALibLeftMultBy {
final AColGroup g = groups.get(i);
tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, g, result, i, groups.size()));
}
-
+
for(Future<Object> tret : pool.invokeAll(tasks))
tret.get();
pool.shutdown();
@@ -228,26 +230,56 @@ public class CLALibLeftMultBy {
}
private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k,
- int numColumns, boolean overlapping) {
+ int numColumns, boolean overlapping) {
if(that.isEmpty()) {
ret.setNonZeros(0);
return ret;
}
+ boolean containsSDC = false;
+
+ for(AColGroup g : colGroups) {
+ if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle)
+ containsSDC = true;
+ }
+
+ final List<AColGroup> filteredGroups = containsSDC ? new ArrayList<>() : colGroups;
+ // a constant colgroup summing the default values.
+ final double[] constV = containsSDC ? new double[numColumns] : null;
+
+ if(containsSDC) {
+ for(AColGroup g : colGroups) {
+ if(g instanceof ColGroupSDC)
+ filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
+ else if(g instanceof ColGroupSDCSingle)
+ filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
+ else
+ filteredGroups.add(g);
+ }
+ }
+
ret.allocateDenseBlock();
- if(k == 1)
- leftMultByMatrixPrimitive(colGroups, that, ret, numColumns, 0, that.getNumRows());
+ if(k == 1) {
+ leftMultByMatrixPrimitive(filteredGroups, that, ret, numColumns, 0, that.getNumRows());
+ if(containsSDC) {
+ MatrixBlock rowSum = that.rowSum();
+ if(rowSum.isInSparseFormat())
+ rowSum.sparseToDense();
+ double[] rowSums = rowSum.getDenseBlockValues();
+ outerProduct(rowSums, constV, ret.getDenseBlockValues());
+ }
+ }
else {
try {
final ExecutorService pool = CommonThreadPool.get(k);
final ArrayList<Callable<MatrixBlock>> tasks = new ArrayList<>();
final int rowBlockSize = that.getNumRows() < 8 ? 1 : Math.min(Math.max(that.getNumRows() / k, 1), 8);
- // final int rowBlockSize = 4;
+ double[] rowSums = null;
if(overlapping) {
- for(AColGroup g : colGroups) {
+ for(AColGroup g : filteredGroups) {
MatrixBlock tmpRet = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false);
tmpRet.allocateDenseBlock();
for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize)
@@ -256,6 +288,12 @@ public class CLALibLeftMultBy {
}
List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
+ if(containsSDC) {
+ MatrixBlock rowSum = that.rowSum();
+ if(rowSum.isInSparseFormat())
+ rowSum.sparseToDense();
+ rowSums = rowSum.getDenseBlockValues();
+ }
pool.shutdown();
BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
for(Future<MatrixBlock> future : futures)
@@ -264,36 +302,40 @@ public class CLALibLeftMultBy {
else {
if(rowBlockSize > 2) {
for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize) {
- tasks.add(new LeftMatrixColGroupMultTaskNew(colGroups, that, ret, numColumns, blo,
+ tasks.add(new LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, numColumns, blo,
Math.min(blo + rowBlockSize, that.getNumRows())));
}
}
else {
-
- List<List<AColGroup>> split = split(colGroups, Math.max(k / that.getNumRows(), 1));
+ List<List<AColGroup>> split = split(filteredGroups, Math.max(k / that.getNumRows(), 1));
for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize) {
for(List<AColGroup> gr : split)
tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, numColumns, blo,
Math.min(blo + rowBlockSize, that.getNumRows())));
}
-
- // for(AColGroup g : colGroups)
- // for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize)
- // tasks.add(new LeftMatrixColGroupMultTaskOld(g, that, ret, blo,
- // Math.min(blo + rowBlockSize, that.getNumRows()), maxNumValues));
}
List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
+ if(containsSDC) {
+ MatrixBlock rowSum = that.rowSum();
+ if(rowSum.isInSparseFormat())
+ rowSum.sparseToDense();
+ rowSums = rowSum.getDenseBlockValues();
+ }
pool.shutdown();
for(Future<MatrixBlock> future : futures)
future.get();
}
+ if(containsSDC)
+ outerProduct(rowSums, constV, ret.getDenseBlockValues());
+
}
catch(InterruptedException | ExecutionException e) {
throw new DMLRuntimeException(e);
}
}
+
ret.recomputeNonZeros();
return ret;
}
@@ -311,6 +353,16 @@ public class CLALibLeftMultBy {
return ret;
}
+ private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, final double[] result) {
+ for(int row = 0; row < leftRowSum.length; row++) {
+ final int offOut = rightColumnSum.length * row;
+ final double vLeft = leftRowSum[row];
+ for(int col = 0; col < rightColumnSum.length; col++) {
+ result[offOut + col] += vLeft * rightColumnSum[col];
+ }
+ }
+ }
+
private static class LeftMatrixColGroupMultTaskOld implements Callable<MatrixBlock> {
private final AColGroup _group;
private final MatrixBlock _that;
@@ -379,47 +431,44 @@ public class CLALibLeftMultBy {
}
}
else {
- List<ColGroupValue> v = new ArrayList<>();
- int rowBlockSize = 1;
- List<MatrixBlock> preAgg = new ArrayList<>();
- int colGroupBlocking = 16;
- for(int j = 0; j < colGroupBlocking; j++) {
- MatrixBlock m = new MatrixBlock(1, 1, false);
- m.allocateDenseBlock();
- preAgg.add(m);
- }
+ // The number of rows to process together
+ final int rowBlockSize = 1;
+ // The number of column groups to process together
+ final int colGroupBlocking = 16;
+
+ // Allocate pre Aggregate Array List
+ final List<MatrixBlock> preAgg = populatePreAggregate(colGroupBlocking);
+ // Allocate a ColGroupValue array for the Column Groups of Value Type.
+ final List<ColGroupValue> ColGroupValues = preFilterAndMultiply(colGroups, that, ret, numColumns, rl, ru);
+ // Allocate temporary Result matrix.
MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, numColumns, false);
- for(int j = 0; j < colGroups.size(); j++) {
- AColGroup a = colGroups.get(j);
- if(a instanceof ColGroupValue) {
- ColGroupValue av = (ColGroupValue) a;
- v.add(av);
- }
- else
- a.leftMultByMatrix(that, ret, rl, ru);
- }
- Collections.sort(v, Comparator.comparing(AColGroup::getNumValues).reversed());
- // LOG.error(v);
- for(int g = 0; g < v.size(); g += colGroupBlocking) {
+ for(int g = 0; g < ColGroupValues.size(); g += colGroupBlocking) {
final int gEnd = Math.min(g + colGroupBlocking, colGroups.size());
- for(int j = g; j < gEnd && j < v.size(); j++) {
- ColGroupValue cg = v.get(j);
- preAgg.get(j % colGroupBlocking).reset(rowBlockSize, cg.getNumValues(), false);
+
+ // for each column group in the current block allocate the preaggregate array.
+ for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+ ColGroupValue cg = ColGroupValues.get(j);
+ int nVals = cg.getNumValues();
+ preAgg.get(j % colGroupBlocking).reset(rowBlockSize, nVals, false);
}
+
// int colBlockSize = 16000;
int colBlockSize = 64000;
+ // For each row block
for(int h = rl; h < ru; h += rowBlockSize) {
+ // For each column block
for(int i = 0; i < that.getNumColumns(); i += colBlockSize) {
- for(int j = g; j < gEnd && j < v.size(); j++) {
- v.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), h,
+ // Pre Aggregate each column group in block
+ for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+ ColGroupValues.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), h,
Math.min(h + rowBlockSize, ru), i, Math.min(i + colBlockSize, that.getNumColumns()));
}
}
- for(int j = g; j < gEnd && j < v.size(); j++) {
- ColGroupValue vj = v.get(j);
+ for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+ ColGroupValue vj = ColGroupValues.get(j);
MatrixBlock preAggJ = preAgg.get(j % colGroupBlocking);
preAggJ.recomputeNonZeros();
tmpRes.reset(rowBlockSize, vj.getNumCols(), false);
@@ -431,4 +480,32 @@ public class CLALibLeftMultBy {
}
}
}
+
+ private static List<MatrixBlock> populatePreAggregate(int colGroupBlocking) {
+ final List<MatrixBlock> preAgg = new ArrayList<>();
+ // poplate the preAgg array.
+ for(int j = 0; j < colGroupBlocking; j++) {
+
+ MatrixBlock m = new MatrixBlock(1, 1, false);
+ m.allocateDenseBlock();
+ preAgg.add(m);
+ }
+ return preAgg;
+ }
+
+ private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> colGroups, MatrixBlock that,
+ MatrixBlock ret, int numColumns, int rl, int ru) {
+ final List<ColGroupValue> ColGroupValues = new ArrayList<>();
+ for(int j = 0; j < colGroups.size(); j++) {
+ AColGroup a = colGroups.get(j);
+ if(a instanceof ColGroupValue) {
+ ColGroupValue av = (ColGroupValue) a;
+ ColGroupValues.add(av);
+ }
+ else
+ a.leftMultByMatrix(that, ret, rl, ru);
+ }
+ Collections.sort(ColGroupValues, Comparator.comparing(AColGroup::getNumValues).reversed());
+ return ColGroupValues;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 8d2d672..160f8e9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -975,13 +975,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
}
/**
- * Wrapper method for reduceall-colSum of a matrix.
+ * Wrapper method for single threaded reduceall-colSum of a matrix.
*
* @return A new MatrixBlock containing the column sums of this matrix.
*/
public MatrixBlock colSum() {
AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uack+", 1);
- return aggregateUnaryOperations(op, null, 1000, null);
+ return aggregateUnaryOperations(op, null, 1000, null, true);
+ }
+
+ /**
+ * Wrapper method for single threaded reduceall-rowSum of a matrix.
+ *
+ * @return A new MatrixBlock containing the row sums of this matrix.
+ */
+ public MatrixBlock rowSum(){
+ AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uark+", 1);
+ return aggregateUnaryOperations(op, null, 1000, null, true);
}
/**