You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/08/26 20:24:21 UTC

[systemds] branch master updated: [SYSTEMDS-3105] CLA Left MM shared common element sum

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 217e31d  [SYSTEMDS-3105] CLA Left MM shared common element sum
217e31d is described below

commit 217e31d87b2ffc745c8a8e1d56cee80e7013aa35
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Thu Aug 26 14:20:12 2021 +0200

    [SYSTEMDS-3105] CLA Left MM shared common element sum
    
    This commit modifies the Left Matrix Multiplication to share a
    common element sum across all SDC column groups, allowing them to
    skip all their common elements.
    
    This improves the performance of LMM on InfiniMnist by 5-10x,
    depending on the number of rows in the left-hand uncompressed matrix.
    
    Closes #1376
---
 .../runtime/compress/colgroup/ColGroupFactory.java |  39 +++--
 .../runtime/compress/colgroup/ColGroupSDC.java     |   4 +-
 .../runtime/compress/colgroup/ColGroupValue.java   |   1 +
 .../runtime/compress/colgroup/offset/AOffset.java  |   7 +-
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 163 +++++++++++++++------
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  14 +-
 6 files changed, 163 insertions(+), 65 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 1d7b97d..f42e382 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
 import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
 import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
 import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
+import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory.CostType;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
@@ -214,34 +215,40 @@ public final class ColGroupFactory {
 		CompressedSizeInfoColGroup cg, int[] colIndexes) {
 		try {
 			final int nrUniqueEstimate = cg.getNumVals();
-			final CompressionType estimatedBestCompressionType = cg.getBestCompressionType();
+			CompressionType estimatedBestCompressionType = cg.getBestCompressionType();
+			
+			if(estimatedBestCompressionType == CompressionType.SDC && cs.costComputationType == CostType.W_TREE) {
+				if(cg.getCompressionSize(CompressionType.DDC) * 3 < cg.getCompressionSize(CompressionType.SDC))
+					estimatedBestCompressionType = CompressionType.DDC;
+			}
+
 			if(estimatedBestCompressionType == CompressionType.UNCOMPRESSED) {
 				// shortcut if uncompressed
 				return new ColGroupUncompressed(colIndexes, in, cs.transposed);
 			}
 			else if(estimatedBestCompressionType == CompressionType.SDC && colIndexes.length == 1 &&
 				in.isInSparseFormat() && cs.transposed) {
-				// shortcut for creating SDC!
-				// throw new NotImplementedException();
+
 				return compressSDCZero(in.getSparseBlock(), colIndexes, in.getNumColumns(),
 					tmp.getDblCountMap(nrUniqueEstimate));
 			}
 			else {
-				ABitmap ubm;
-				if(colIndexes.length > 1)
-					ubm = BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
-						tmp.getDblArrayMap(nrUniqueEstimate));
-				else
-					ubm = BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
-
-				CompressedSizeEstimator estimator = new CompressedSizeEstimatorExact(in, cs);
+				final int numRows = cs.transposed ? in.getNumColumns() : in.getNumRows();
 
-				CompressedSizeInfoColGroup sizeInfo = new CompressedSizeInfoColGroup(
-					estimator.estimateCompressedColGroupSize(ubm, colIndexes), cs.validCompressions, ubm);
+				if(colIndexes.length > 1) {
+					final ABitmap ubm = BitmapEncoder.extractBitmapMultiColumns(colIndexes, in, cs.transposed,
+						tmp.getDblArrayMap(nrUniqueEstimate));
+					CompressedSizeEstimator estimator = new CompressedSizeEstimatorExact(in, cs);
+					CompressedSizeInfoColGroup sizeInfo = new CompressedSizeInfoColGroup(
+						estimator.estimateCompressedColGroupSize(ubm, colIndexes), cs.validCompressions, ubm);
+					return compress(colIndexes, numRows, ubm, estimatedBestCompressionType, cs, in,
+						sizeInfo.getTupleSparsity());
+				}
+				else {
+					final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cs.transposed, nrUniqueEstimate);
+					return compress(colIndexes, numRows, ubm, estimatedBestCompressionType, cs, in, 1.0);
+				}
 
-				int numRows = cs.transposed ? in.getNumColumns() : in.getNumRows();
-				return compress(colIndexes, numRows, ubm, sizeInfo.getBestCompressionType(cs), cs, in,
-					sizeInfo.getTupleSparsity());
 			}
 
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 23adae9..ffe38b1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -280,7 +280,7 @@ public class ColGroupSDC extends ColGroupValue {
 		final double[] mV = m.getDenseBlockValues();
 		final double[] preAV = preAgg.getDenseBlockValues();
 		final int numVals = getNumValues();
-		AIterator itStart = _indexes.getIterator(cl);
+		final AIterator itStart = _indexes.getIterator(cl);
 		AIterator it = null;
 		for(int rowLeft = rl, offOut = 0; rowLeft < ru; rowLeft++, offOut += numVals) {
 			final int offLeft = rowLeft * _numRows;
@@ -300,7 +300,7 @@ public class ColGroupSDC extends ColGroupValue {
 				preAV[def] += mV[offLeft + rc];
 			}
 		}
-		if(it != null)
+		if(it != null && cu < m.getNumColumns())
 			_indexes.cacheIterator(it, cu + 1);
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 7307bc2..6844095 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -887,6 +887,7 @@ public abstract class ColGroupValue extends ColGroupCompressed implements Clonea
 		if(right.length != rightColumns.length)
 			throw new DMLCompressionException(
 				"Error right not equal length " + right.length + " " + rightColumns.length);
+				
 		for(int row = 0; row < leftRows.length; row++) {
 			final int outputRowOffset = leftRows[row] * outCols;
 			final double vLeft = left[row];
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
index 7f42240..ac359b3 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java
@@ -59,8 +59,11 @@ public abstract class AOffset implements Serializable {
 	public AIterator getIterator(int row) {
 		if(skipIterators != null) {
 			Map<Integer, AIterator> sk = skipIterators.get();
-			if(sk != null && sk.containsKey(row))
-				return sk.get(row).clone();
+			if(sk != null && sk.containsKey(row)){
+				AIterator it = sk.get(row);
+				if(it != null)
+					return it.clone();
+			}
 		}
 
 		AIterator it = getIterator();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 04ccdc2..3ca657b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -33,6 +33,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
 import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
 import org.apache.sysds.runtime.functionobjects.Plus;
@@ -92,7 +94,7 @@ public class CLALibLeftMultBy {
 		final boolean overlapping = cmb.isOverlapping();
 		List<AColGroup> groups = cmb.getColGroups();
 		result.allocateDenseBlock();
-		
+
 		if(overlapping) {
 			LOG.warn("Inefficient TSMM with overlapping matrix could be implemented multi-threaded but is not yet.");
 			leftMultByCompressedTransposedMatrix(groups, groups, result);
@@ -110,7 +112,7 @@ public class CLALibLeftMultBy {
 					final AColGroup g = groups.get(i);
 					tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, g, result, i, groups.size()));
 				}
-				
+
 				for(Future<Object> tret : pool.invokeAll(tasks))
 					tret.get();
 				pool.shutdown();
@@ -228,26 +230,56 @@ public class CLALibLeftMultBy {
 	}
 
 	private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k,
-		int numColumns,  boolean overlapping) {
+		int numColumns, boolean overlapping) {
 
 		if(that.isEmpty()) {
 			ret.setNonZeros(0);
 			return ret;
 		}
 
+		boolean containsSDC = false;
+
+		for(AColGroup g : colGroups) {
+			if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle)
+				containsSDC = true;
+		}
+
+		final List<AColGroup> filteredGroups = containsSDC ? new ArrayList<>() : colGroups;
+		// a constant colgroup summing the default values.
+		final double[] constV = containsSDC ? new double[numColumns] : null;
+
+		if(containsSDC) {
+			for(AColGroup g : colGroups) {
+				if(g instanceof ColGroupSDC)
+					filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
+				else if(g instanceof ColGroupSDCSingle)
+					filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
+				else
+					filteredGroups.add(g);
+			}
+		}
+
 		ret.allocateDenseBlock();
 
-		if(k == 1)
-			leftMultByMatrixPrimitive(colGroups, that, ret, numColumns, 0, that.getNumRows());
+		if(k == 1) {
+			leftMultByMatrixPrimitive(filteredGroups, that, ret, numColumns, 0, that.getNumRows());
+			if(containsSDC) {
+				MatrixBlock rowSum = that.rowSum();
+				if(rowSum.isInSparseFormat())
+					rowSum.sparseToDense();
+				double[] rowSums = rowSum.getDenseBlockValues();
+				outerProduct(rowSums, constV, ret.getDenseBlockValues());
+			}
+		}
 		else {
 			try {
 				final ExecutorService pool = CommonThreadPool.get(k);
 				final ArrayList<Callable<MatrixBlock>> tasks = new ArrayList<>();
 				final int rowBlockSize = that.getNumRows() < 8 ? 1 : Math.min(Math.max(that.getNumRows() / k, 1), 8);
-				// final int rowBlockSize = 4;
+				double[] rowSums = null;
 
 				if(overlapping) {
-					for(AColGroup g : colGroups) {
+					for(AColGroup g : filteredGroups) {
 						MatrixBlock tmpRet = new MatrixBlock(ret.getNumRows(), ret.getNumColumns(), false);
 						tmpRet.allocateDenseBlock();
 						for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize)
@@ -256,6 +288,12 @@ public class CLALibLeftMultBy {
 
 					}
 					List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
+					if(containsSDC) {
+						MatrixBlock rowSum = that.rowSum();
+						if(rowSum.isInSparseFormat())
+							rowSum.sparseToDense();
+						rowSums = rowSum.getDenseBlockValues();
+					}
 					pool.shutdown();
 					BinaryOperator op = new BinaryOperator(Plus.getPlusFnObject());
 					for(Future<MatrixBlock> future : futures)
@@ -264,36 +302,40 @@ public class CLALibLeftMultBy {
 				else {
 					if(rowBlockSize > 2) {
 						for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize) {
-							tasks.add(new LeftMatrixColGroupMultTaskNew(colGroups, that, ret, numColumns, blo,
+							tasks.add(new LeftMatrixColGroupMultTaskNew(filteredGroups, that, ret, numColumns, blo,
 								Math.min(blo + rowBlockSize, that.getNumRows())));
 						}
 					}
 					else {
-
-						List<List<AColGroup>> split = split(colGroups, Math.max(k / that.getNumRows(), 1));
+						List<List<AColGroup>> split = split(filteredGroups, Math.max(k / that.getNumRows(), 1));
 						for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize) {
 							for(List<AColGroup> gr : split)
 								tasks.add(new LeftMatrixColGroupMultTaskNew(gr, that, ret, numColumns, blo,
 									Math.min(blo + rowBlockSize, that.getNumRows())));
 						}
-
-						// for(AColGroup g : colGroups)
-						// for(int blo = 0; blo < that.getNumRows(); blo += rowBlockSize)
-						// tasks.add(new LeftMatrixColGroupMultTaskOld(g, that, ret, blo,
-						// Math.min(blo + rowBlockSize, that.getNumRows()), maxNumValues));
 					}
 
 					List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
+					if(containsSDC) {
+						MatrixBlock rowSum = that.rowSum();
+						if(rowSum.isInSparseFormat())
+							rowSum.sparseToDense();
+						rowSums = rowSum.getDenseBlockValues();
+					}
 					pool.shutdown();
 					for(Future<MatrixBlock> future : futures)
 						future.get();
 				}
 
+				if(containsSDC)
+					outerProduct(rowSums, constV, ret.getDenseBlockValues());
+
 			}
 			catch(InterruptedException | ExecutionException e) {
 				throw new DMLRuntimeException(e);
 			}
 		}
+
 		ret.recomputeNonZeros();
 		return ret;
 	}
@@ -311,6 +353,16 @@ public class CLALibLeftMultBy {
 		return ret;
 	}
 
+	private static void outerProduct(final double[] leftRowSum, final double[] rightColumnSum, final double[] result) {
+		for(int row = 0; row < leftRowSum.length; row++) {
+			final int offOut = rightColumnSum.length * row;
+			final double vLeft = leftRowSum[row];
+			for(int col = 0; col < rightColumnSum.length; col++) {
+				result[offOut + col] += vLeft * rightColumnSum[col];
+			}
+		}
+	}
+
 	private static class LeftMatrixColGroupMultTaskOld implements Callable<MatrixBlock> {
 		private final AColGroup _group;
 		private final MatrixBlock _that;
@@ -379,47 +431,44 @@ public class CLALibLeftMultBy {
 			}
 		}
 		else {
-			List<ColGroupValue> v = new ArrayList<>();
-			int rowBlockSize = 1;
-			List<MatrixBlock> preAgg = new ArrayList<>();
-			int colGroupBlocking = 16;
-			for(int j = 0; j < colGroupBlocking; j++) {
-				MatrixBlock m = new MatrixBlock(1, 1, false);
-				m.allocateDenseBlock();
-				preAgg.add(m);
-			}
+			// The number of rows to process together
+			final int rowBlockSize = 1;
+			// The number of column groups to process together
+			final int colGroupBlocking = 16;
+
+			// Allocate pre Aggregate Array List
+			final List<MatrixBlock> preAgg = populatePreAggregate(colGroupBlocking);
+			// Allocate a ColGroupValue array for the Column Groups of Value Type.
+			final List<ColGroupValue> ColGroupValues = preFilterAndMultiply(colGroups, that, ret, numColumns, rl, ru);
 
+			// Allocate temporary Result matrix.
 			MatrixBlock tmpRes = new MatrixBlock(rowBlockSize, numColumns, false);
 
-			for(int j = 0; j < colGroups.size(); j++) {
-				AColGroup a = colGroups.get(j);
-				if(a instanceof ColGroupValue) {
-					ColGroupValue av = (ColGroupValue) a;
-					v.add(av);
-				}
-				else
-					a.leftMultByMatrix(that, ret, rl, ru);
-			}
-			Collections.sort(v, Comparator.comparing(AColGroup::getNumValues).reversed());
-			// LOG.error(v);
-			for(int g = 0; g < v.size(); g += colGroupBlocking) {
+			for(int g = 0; g < ColGroupValues.size(); g += colGroupBlocking) {
 				final int gEnd = Math.min(g + colGroupBlocking, colGroups.size());
-				for(int j = g; j < gEnd && j < v.size(); j++) {
-					ColGroupValue cg = v.get(j);
-					preAgg.get(j % colGroupBlocking).reset(rowBlockSize, cg.getNumValues(), false);
+
+				// for each column group in the current block allocate the preaggregate array.
+				for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+					ColGroupValue cg = ColGroupValues.get(j);
+					int nVals = cg.getNumValues();
+					preAgg.get(j % colGroupBlocking).reset(rowBlockSize, nVals, false);
 				}
+
 				// int colBlockSize = 16000;
 				int colBlockSize = 64000;
 
+				// For each row block
 				for(int h = rl; h < ru; h += rowBlockSize) {
+					// For each column block
 					for(int i = 0; i < that.getNumColumns(); i += colBlockSize) {
-						for(int j = g; j < gEnd && j < v.size(); j++) {
-							v.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), h,
+						// Pre Aggregate each column group in block
+						for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+							ColGroupValues.get(j).preAggregateDense(that, preAgg.get(j % colGroupBlocking), h,
 								Math.min(h + rowBlockSize, ru), i, Math.min(i + colBlockSize, that.getNumColumns()));
 						}
 					}
-					for(int j = g; j < gEnd && j < v.size(); j++) {
-						ColGroupValue vj = v.get(j);
+					for(int j = g; j < gEnd && j < ColGroupValues.size(); j++) {
+						ColGroupValue vj = ColGroupValues.get(j);
 						MatrixBlock preAggJ = preAgg.get(j % colGroupBlocking);
 						preAggJ.recomputeNonZeros();
 						tmpRes.reset(rowBlockSize, vj.getNumCols(), false);
@@ -431,4 +480,32 @@ public class CLALibLeftMultBy {
 			}
 		}
 	}
+
+	private static List<MatrixBlock> populatePreAggregate(int colGroupBlocking) {
+		final List<MatrixBlock> preAgg = new ArrayList<>();
+		// populate the preAgg array.
+		for(int j = 0; j < colGroupBlocking; j++) {
+
+			MatrixBlock m = new MatrixBlock(1, 1, false);
+			m.allocateDenseBlock();
+			preAgg.add(m);
+		}
+		return preAgg;
+	}
+
+	private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> colGroups, MatrixBlock that,
+		MatrixBlock ret, int numColumns, int rl, int ru) {
+		final List<ColGroupValue> ColGroupValues = new ArrayList<>();
+		for(int j = 0; j < colGroups.size(); j++) {
+			AColGroup a = colGroups.get(j);
+			if(a instanceof ColGroupValue) {
+				ColGroupValue av = (ColGroupValue) a;
+				ColGroupValues.add(av);
+			}
+			else
+				a.leftMultByMatrix(that, ret, rl, ru);
+		}
+		Collections.sort(ColGroupValues, Comparator.comparing(AColGroup::getNumValues).reversed());
+		return ColGroupValues;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 8d2d672..160f8e9 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -975,13 +975,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 	}
 
 	/**
-	 * Wrapper method for reduceall-colSum of a matrix.
+	 * Wrapper method for single threaded reduceall-colSum of a matrix.
 	 * 
 	 * @return A new MatrixBlock containing the column sums of this matrix.
 	 */
 	public MatrixBlock colSum() {
 		AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uack+", 1);
-		return aggregateUnaryOperations(op, null, 1000, null);
+		return aggregateUnaryOperations(op, null, 1000, null, true);
+	}
+
+	/**
+	 * Wrapper method for single threaded reduceall-rowSum of a matrix.
+	 * 
+	 * @return A new MatrixBlock containing the row sums of this matrix.
+	 */
+	public MatrixBlock rowSum(){
+		AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uark+", 1);
+		return aggregateUnaryOperations(op, null, 1000, null, true);
 	}
 
 	/**