You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/08/29 21:50:33 UTC

[systemds] 01/02: [MINOR] CLA update tsmm

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 6caa9c02e81de88f691763f25e93497b0b0d2381
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sun Aug 29 23:35:48 2021 +0200

    [MINOR] CLA update tsmm
    
    This commit does two things: first, it optimizes the tsmm by exploiting
    common elements in SDC groups, and second, it updates the cost calculation
    to compute some cost for single column groups.
---
 .../sysds/runtime/compress/colgroup/AColGroup.java |   7 +
 .../compress/colgroup/ColGroupCompressed.java      |   6 +-
 .../runtime/compress/colgroup/ColGroupSDC.java     |   3 +-
 .../compress/colgroup/ColGroupSDCZeros.java        |  21 +-
 .../compress/colgroup/ColGroupUncompressed.java    |  15 ++
 .../compress/cost/ComputationCostEstimator.java    |   7 +-
 .../runtime/compress/lib/CLALibLeftMultBy.java     | 219 ++++++++++++++++-----
 7 files changed, 222 insertions(+), 56 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index e675426..0460cf2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -516,6 +516,13 @@ public abstract class AColGroup implements Serializable {
 	 */
 	public abstract AColGroup replace(double pattern, double replace);
 
+	/**
+	 * Compute the column sum
+	 * 
+	 * @param c The array to add the column sum to.
+	 */
+	public abstract void computeColSums(double[] c);
+
 	@Override
 	public String toString() {
 		StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
index 968e261..c060596 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupCompressed.java
@@ -36,7 +36,7 @@ import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 public abstract class ColGroupCompressed extends AColGroup {
 
 	private static final long serialVersionUID = 6219835795420081223L;
-	
+
 	final protected int _numRows;
 
 	protected ColGroupCompressed(int numRows) {
@@ -72,6 +72,10 @@ public abstract class ColGroupCompressed extends AColGroup {
 
 	protected abstract void computeRowSums(double[] c, boolean square, int rl, int ru);
 
+	public void computeColSums(double[] c){
+		computeColSums(c, false);
+	}
+
 	protected abstract void computeColSums(double[] c, boolean square);
 
 	protected abstract void computeRowMxx(double[] c, Builtin builtin, int rl, int ru);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index ffe38b1..5c6d365 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -507,11 +507,10 @@ public class ColGroupSDC extends ColGroupValue {
 
 	@Override
 	public Dictionary preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
-
 		final AIterator itThat = that._indexes.getIterator();
 		final AIterator itThis = _indexes.getIterator();
 		final int nCol = that._colIndexes.length;
-		final int defThis = this.getNumValues() * nCol - nCol;
+		final int defThis = getNumValues() - 1;
 
 		while(itThat.hasNext()) {
 			final int thatV = itThat.value();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index 030cf83..2397fd1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -436,7 +436,26 @@ public class ColGroupSDCZeros extends ColGroupValue {
 
 	@Override
 	public Dictionary preAggregateThatSDCStructure(ColGroupSDC that, Dictionary ret, boolean preModified) {
-		throw new NotImplementedException();
+		if(preModified){
+			final AIterator itThat = that._indexes.getIterator();
+			final AIterator itThis = _indexes.getIterator();
+			final int nCol = that._colIndexes.length;
+	
+			while(itThat.hasNext() && itThis.hasNext()) {
+				if(itThat.value() == itThis.value()) {
+					final int fr = that.getIndex(itThat.getDataIndexAndIncrement());
+					final int to = getIndex(itThis.getDataIndexAndIncrement());
+					that._dict.addToEntry(ret, fr, to, nCol);
+				}
+				else if(itThat.value() < itThis.value())
+					itThat.next();
+				else
+					itThis.next();
+			}
+			return ret;
+		}else{
+			throw new NotImplementedException("Not implemented not PreModded preaggregate of SDC");
+		}
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index a9412d0..ab94646 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -620,4 +620,19 @@ public class ColGroupUncompressed extends AColGroup {
 		MatrixBlock replaced = _data.replaceOperations(new MatrixBlock(), pattern, replace);
 		return new ColGroupUncompressed(_colIndexes, replaced);
 	}
+
+	@Override
+	public void computeColSums(double[] c) {
+		// Add the dense column sums of _data into c at this group's column indexes.
+		MatrixBlock colSum = _data.colSum();
+		if(colSum.isInSparseFormat()) {
+			throw new NotImplementedException();
+		}
+		else {
+			double[] dv = colSum.getDenseBlockValues();
+			for(int i = 0; i < _colIndexes.length; i++)
+				c[_colIndexes[i]] += dv[i];
+			
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index fe4cb83..0152edc 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -79,9 +79,12 @@ public class ComputationCostEstimator implements ICostEstimate {
 		cost += _decompressions * decompressionCost(g);
 		cost += _overlappingDecompressions * overlappingDecompressionCost(g);
 		// 16 is assuming that the left side is 16 rows.
-		cost += _leftMultiplications * leftMultCost(g) * 16;
+		double lmc = leftMultCost(g) * 16;
+		cost += _leftMultiplications * lmc;
 		// 16 is assuming that the right side is 16 rows.
-		cost += _rightMultiplications * rightMultCost(g) * 16;
+		double rmc = rightMultCost(g) * 16;
+		cost += _rightMultiplications * rmc;
+		cost += _compressedMultiplication * (lmc + rmc);
 		cost += _dictionaryOps * dictionaryOpsCost(g);
 		return cost;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 10cd2e8..c0b4f40 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -91,53 +91,65 @@ public class CLALibLeftMultBy {
 	}
 
 	public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock result, int k) {
-		final int numColumns = cmb.getNumColumns();
 		final boolean overlapping = cmb.isOverlapping();
-		List<AColGroup> groups = cmb.getColGroups();
+		final List<AColGroup> groups = cmb.getColGroups();
+
 		result.allocateDenseBlock();
 
 		if(overlapping) {
 			LOG.warn("Inefficient TSMM with overlapping matrix could be implemented multi-threaded but is not yet.");
 			leftMultByCompressedTransposedMatrix(groups, groups, result);
 		}
-		else if(k <= 1) {
-			for(int i = 0; i < groups.size(); i++)
-				leftMultByCompressedTransposedMatrix(groups.get(i), groups, result, i, groups.size());
-		}
 		else {
-			try {
-				ExecutorService pool = CommonThreadPool.get(k);
-				ArrayList<Callable<Object>> tasks = new ArrayList<>();
+			final boolean containsSDC = containsSDC(groups);
+			final int numColumns = cmb.getNumColumns();
+			final double[] constV = containsSDC ? new double[cmb.getNumColumns()] : null;
+			final List<AColGroup> filteredGroups = filterSDCGroups(groups, constV);
+			final double[] colSums = containsSDC ? new double[cmb.getNumColumns()] : null;
 
+			if(containsSDC)
 				for(int i = 0; i < groups.size(); i++) {
-					final AColGroup g = groups.get(i);
-					tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, g, result, i, groups.size()));
+					AColGroup gi = groups.get(i);
+					if(!(gi instanceof ColGroupSDC || gi instanceof ColGroupSDCSingle))
+						gi.computeColSums(colSums);
 				}
 
-				for(Future<Object> tret : pool.invokeAll(tasks))
-					tret.get();
-				pool.shutdown();
-			}
-			catch(InterruptedException | ExecutionException e) {
-				throw new DMLRuntimeException(e);
+			if(k <= 1)
+				tsmmColGroups(groups, filteredGroups, result);
+			else
+				tsmmColGroupsParallel(groups, filteredGroups, result, k);
+
+			double[] retV = result.getDenseBlockValues();
+
+			// Move values in the lower part of the matrix to the upper part
+			copyToUpperTriangle(retV, numColumns);
+
+			// add the correction layer for the subtracted common values.
+			if(colSums != null) {
+				outerProduct(colSums, constV, retV);
+				addToUpperTriangle(retV, numColumns);
 			}
 		}
-		// Move values in the lower part of the matrix to the upper part
-		copyToUpperTriangle(result.getDenseBlockValues(), numColumns);
-		// calculate the number of non zeros, and allocate all value locations by copying upper triangle back to bottom.
+
 		long nnz = LinearAlgebraUtils.copyUpperToLowerTriangle(result);
 		result.setNonZeros(nnz);
-		// Evaluate if the output should be sparsely allocated.
 		result.examSparsity();
 	}
 
 	private static void copyToUpperTriangle(final double[] c, final int cols) {
 		for(int i = 0, offC = 0; i < cols; i++, offC += cols)
-			for(int j = i, offR = i * cols; j < cols; j++, offR += cols) {
+			for(int j = (i + 1), offR = (i + 1) * cols; j < cols; j++, offR += cols) {
 				final double prev = c[offC + j];
 				if(prev == 0)
 					c[offC + j] = c[i + offR];
+				c[i + offR] = 0;
 			}
+	}
+
+	private static void addToUpperTriangle(final double[] c, final int cols) {
+		for(int i = 0, offC = 0; i < cols; i++, offC += cols)
+			for(int j = (i + 1), offR = (i + 1) * cols; j < cols; j++, offR += cols)
+				c[offC + j] += c[i + offR];
 
 	}
 
@@ -181,15 +193,6 @@ public class CLALibLeftMultBy {
 		private final int _start;
 		private final int _end;
 
-		protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret,
-			int start, int end) {
-			_groups = groups;
-			_left = left;
-			_ret = ret;
-			_start = start;
-			_end = end;
-		}
-
 		protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret) {
 			_groups = groups;
 			_left = left;
@@ -227,9 +230,85 @@ public class CLALibLeftMultBy {
 			else
 				rhs.tsmm(ret);
 		}
+	}
+
+	private static void tsmmColGroups(List<AColGroup> groups, List<AColGroup> filteredGroups, MatrixBlock ret) {
+		for(int i = 0; i < groups.size(); i++)
+			tsmmColGroupsIndexI(groups, filteredGroups, ret, i);
+	}
+
+	private static void tsmmColGroupsParallel(List<AColGroup> groups, List<AColGroup> filteredGroups, MatrixBlock ret,
+		int k) {
+		try {
+			ExecutorService pool = CommonThreadPool.get(k);
+			ArrayList<Callable<Object>> tasks = new ArrayList<>();
+
+			for(int i = 0; i < filteredGroups.size(); i++)
+				tasks.add(new tsmmColGroupTask(groups, filteredGroups, ret, i));
+
+			for(Future<Object> tret : pool.invokeAll(tasks))
+				tret.get();
+			pool.shutdown();
+		}
+		catch(InterruptedException | ExecutionException e) {
+			throw new DMLRuntimeException(e);
+		}
+	}
 
+	private static void tsmmColGroupsIndexI(List<AColGroup> groups, List<AColGroup> filteredGroups, MatrixBlock ret,
+		int i) {
+		final AColGroup full_lhs = groups.get(i);
+		final AColGroup lhs = filteredGroups.get(i);
+		final int start = i;
+		final int end = groups.size();
+		full_lhs.tsmm(ret);
+		boolean isSDC = full_lhs instanceof ColGroupSDC || full_lhs instanceof ColGroupSDCSingle;
+		// if(isSDC) {
+		// Arrays.fill(tmp, 0);
+		// full_lhs.computeColSums(tmp);
+		// }
+		for(int id = start + 1; id < end; id++) {
+			final AColGroup full_rhs = groups.get(id);
+			final AColGroup rhs = filteredGroups.get(id);
+			if(isSDC && (full_rhs instanceof ColGroupSDC || full_rhs instanceof ColGroupSDCSingle)) {
+				// Full
+				full_lhs.leftMultByAColGroup(full_rhs, ret);
+
+				// Partial
+				// full_lhs.leftMultByAColGroup(rhs, ret);
+				// multiplyWithMostCommonElement(tmp, (ColGroupValue) full_rhs, ret);
+			}
+			else {
+				lhs.leftMultByAColGroup(rhs, ret);
+			}
+		}
 	}
 
+	// private static void multiplyWithMostCommonElement(double[] colSum, ColGroupValue full, MatrixBlock ret) {
+	// final ADictionary d = full.getDictionary();
+	// final double[] result = ret.getDenseBlockValues();
+	// final int numVals = full.getNumValues();
+	// final int[] colIndexes = full.getColIndices();
+	// final int numColumns = ret.getNumColumns();
+	// if(d instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) d).getMatrixBlock().isInSparseFormat()) {
+	// throw new NotImplementedException();
+	// }
+	// else {
+	// final int offsetToDefault = numVals * full.getNumCols() - numVals;
+	// final double[] dv = d.getValues();
+	// for(int row = 0; row < colSum.length; row++) {
+
+	// final int offOut = numColumns * row;
+	// final double vLeft = colSum[row];
+	// if(vLeft != 0) {
+	// for(int colId = 0; colId < colIndexes.length; colId++) {
+	// result[offOut + colIndexes[colId]] += vLeft * dv[offsetToDefault + colId];
+	// }
+	// }
+	// }
+	// }
+	// }
+
 	private static MatrixBlock leftMultByMatrix(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int k,
 		boolean overlapping) {
 
@@ -237,28 +316,13 @@ public class CLALibLeftMultBy {
 			ret.setNonZeros(0);
 			return ret;
 		}
-		final int numColumnsOut = ret.getNumColumns();
-		boolean containsSDC = false;
 
-		for(AColGroup g : colGroups) {
-			if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle)
-				containsSDC = true;
-		}
+		final int numColumnsOut = ret.getNumColumns();
+		final boolean containsSDC = containsSDC(colGroups);
 
-		final List<AColGroup> filteredGroups = containsSDC ? new ArrayList<>() : colGroups;
 		// a constant colgroup summing the default values.
 		final double[] constV = containsSDC ? new double[numColumnsOut] : null;
-
-		if(containsSDC) {
-			for(AColGroup g : colGroups) {
-				if(g instanceof ColGroupSDC)
-					filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
-				else if(g instanceof ColGroupSDCSingle)
-					filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
-				else
-					filteredGroups.add(g);
-			}
-		}
+		final List<AColGroup> filteredGroups = filterSDCGroups(colGroups, constV);
 
 		ret.allocateDenseBlock();
 		final double[] rowSums = containsSDC ? new double[that.getNumRows()] : null;
@@ -418,6 +482,32 @@ public class CLALibLeftMultBy {
 		}
 	}
 
+	private static class tsmmColGroupTask implements Callable<Object> {
+		private final List<AColGroup> _groups;
+		private final List<AColGroup> _filteredGroups;
+		private final MatrixBlock _ret;
+		private final int _index;
+
+		protected tsmmColGroupTask(List<AColGroup> groups, List<AColGroup> filteredGroups, MatrixBlock ret, int i) {
+			_groups = groups;
+			_filteredGroups = filteredGroups;
+			_ret = ret;
+			_index = i;
+		}
+
+		@Override
+		public MatrixBlock call() {
+			try {
+				tsmmColGroupsIndexI(_groups, _filteredGroups, _ret, _index);
+			}
+			catch(Exception e) {
+				e.printStackTrace();
+				throw new DMLRuntimeException(e);
+			}
+			return _ret;
+		}
+	}
+
 	private static void leftMultByMatrixPrimitive(List<AColGroup> colGroups, MatrixBlock that, MatrixBlock ret, int rl,
 		int ru, double[] rowSums) {
 		if(that.isInSparseFormat())
@@ -435,7 +525,7 @@ public class CLALibLeftMultBy {
 			}
 			if(rowSum != null) {
 				final SparseBlock sb = that.getSparseBlock();
-				if(!sb.isEmpty(i)){
+				if(!sb.isEmpty(i)) {
 					final int apos = sb.pos(i);
 					final int alen = sb.size(i) + apos;
 					final double[] aval = sb.values(i);
@@ -538,4 +628,33 @@ public class CLALibLeftMultBy {
 		Collections.sort(ColGroupValues, Comparator.comparing(AColGroup::getNumValues).reversed());
 		return ColGroupValues;
 	}
+
+	private static boolean containsSDC(List<AColGroup> groups) {
+		boolean containsSDC = false;
+
+		for(AColGroup g : groups) {
+			if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle) {
+				containsSDC = true;
+				break;
+			}
+		}
+		return containsSDC;
+	}
+
+	private static List<AColGroup> filterSDCGroups(List<AColGroup> groups, double[] constV) {
+		if(constV != null) {
+			final List<AColGroup> filteredGroups = new ArrayList<>();
+			for(AColGroup g : groups) {
+				if(g instanceof ColGroupSDC)
+					filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
+				else if(g instanceof ColGroupSDCSingle)
+					filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
+				else
+					filteredGroups.add(g);
+			}
+			return filteredGroups;
+		}
+		else
+			return groups;
+	}
 }