You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/01/31 08:28:15 UTC

[systemds] branch main updated: [SYSTEMDS-3283] Multi-threaded ctable instruction

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 09f0be0  [SYSTEMDS-3283] Multi-threaded ctable instruction
09f0be0 is described below

commit 09f0be03c820d6851033bfc6469df7703cac0faa
Author: arnabp <ar...@tugraz.at>
AuthorDate: Mon Jan 31 09:26:46 2022 +0100

    [SYSTEMDS-3283] Multi-threaded ctable instruction
    
    This patch implements a multithreaded version of the
    F = ctable(A, B, W) case. Other cases will be supported
    in the future. Each thread constructs a separate
    CTableMap from a block of rows. Later we cascade-merge
    the partial maps.
    This implementation shows an 8x improvement for
    23M rows with 470K unique values.
    
    Closes #1530.
---
 .../sysds/runtime/functionobjects/CTable.java      | 143 +++++++++++++++++++++
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  25 ++--
 2 files changed, 157 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
index fc44ed1..291effa 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
@@ -24,8 +24,17 @@ import org.apache.sysds.runtime.matrix.data.CTableMap;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
 public class CTable extends ValueFunction 
 {
 	private static final long serialVersionUID = -5374880447194177236L;
@@ -139,4 +148,138 @@ public class CTable extends ValueFunction
 			throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): "+v2);
 		return new Pair<>(new MatrixIndexes(row, col), w);
 	}
+
+	/* Multithreaded CTable (F = ctable(A,B,W))
+	 * Divide the input vectors into equal-sized blocks and assign each block to a task.
+	 * All tasks concurrently build their own CTableMaps.
+	 * Cascade merge the partial maps.
+	 * TODO: Support other cases
+	 */
+	public void execute(MatrixBlock in1, MatrixBlock in2, MatrixBlock w, CTableMap resultMap, int k) {
+		ExecutorService pool = CommonThreadPool.get(k);
+		ArrayList<CTableMap> partialMaps = new ArrayList<>();
+		try {
+			// Assign an equal-sized block to each task
+			List<Callable<Object>> tasks = new ArrayList<>();
+			int[] blockSizes = UtilFunctions.getBlockSizes(in1.getNumRows(), k);
+			// Each task builds its own CTableMap without locking; only the final add to the shared list is synchronized
+			for(int startRow = 0, i = 0; i < blockSizes.length; startRow += blockSizes[i], i++)
+				tasks.add(getPartialCTableTask(in1, in2, w, startRow, blockSizes[i], partialMaps));
+			List<Future<Object>> taskret = pool.invokeAll(tasks);
+			for(var task : taskret)
+				task.get();
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+
+		ArrayList<CTableMap> newPartialMaps = new ArrayList<>();
+		// Cascade-merge all the partial CTableMaps
+		while(partialMaps.size() > 1) {
+			newPartialMaps.clear();
+			List<Callable<Object>> tasks = new ArrayList<>();
+			int count;
+			// Each task merges 2 maps and returns the merged map
+			for (count=0; count+1<partialMaps.size(); count=count+2)
+				tasks.add(getMergePartialCTMapsTask(partialMaps.get(count),
+					partialMaps.get(count+1), newPartialMaps));
+
+			try {
+				List<Future<Object>> taskret = pool.invokeAll(tasks);
+				for(var task : taskret)
+					task.get();
+			}
+			catch(Exception ex) {
+				throw new DMLRuntimeException(ex);
+			}
+			// Carry over the single unpaired map (if any) so it is merged in a later iteration
+			if (count < partialMaps.size())
+				newPartialMaps.add(partialMaps.get(count));
+			partialMaps.clear();
+			partialMaps.addAll(newPartialMaps);
+		}
+		pool.shutdown();
+		// Aggregate every entry of the final merged map into the result map
+		var map = partialMaps.get(0);
+		Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = map.getIterator();
+		while(iter.hasNext()) {
+			LongLongDoubleHashMap.ADoubleEntry e = iter.next();
+			resultMap.aggregate(e.getKey1(), e.getKey2(), e.value);
+		}
+	}
+
+	public Callable<Object> getPartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+		int startInd, int blockSize, ArrayList<CTableMap> pmaps) {
+		return new PartialCTableTask(in1, in2, w, startInd, blockSize, pmaps);
+	}
+
+	public Callable<Object> getMergePartialCTMapsTask(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+		return new MergePartialCTMaps(map1, map2, pmaps);
+	}
+
+	private static class PartialCTableTask implements Callable<Object> {
+		private final MatrixBlock _in1;
+		private final MatrixBlock _in2;
+		private final MatrixBlock _w;
+		private final int _startInd;
+		private final int _blockSize;
+		private final ArrayList<CTableMap> _partialCTmaps;
+
+		protected PartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+			int startRow, int blockSize, ArrayList<CTableMap> pmaps) {
+			_in1 = in1;
+			_in2 = in2;
+			_w = w;
+			_startInd = startRow;
+			_blockSize = blockSize;
+			_partialCTmaps = pmaps;
+		}
+
+		@Override public Object call() throws Exception {
+			CTable ctable = CTable.getCTableFnObject();
+			CTableMap ctmap = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+			int endInd = UtilFunctions.getEndIndex(_in1.getNumRows(), _startInd, _blockSize);
+			for( int i=_startInd; i<endInd; i++ )
+			{
+				double v1 = _in1.quickGetValue(i, 0);
+				double v2 = _in2.quickGetValue(i, 0);
+				double w = _w.quickGetValue(i, 0);
+				ctable.execute(v1, v2, w, false, ctmap);
+			}
+			synchronized(_partialCTmaps) {
+				_partialCTmaps.add(ctmap);
+			}
+			return null;
+		}
+	}
+
+	private static class MergePartialCTMaps implements Callable<Object> {
+		private final CTableMap _map1;
+		private final CTableMap _map2;
+		private final ArrayList<CTableMap> _partialCTmaps;
+
+		protected MergePartialCTMaps(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+			_map1 = map1;
+			_map2 = map2;
+			_partialCTmaps = pmaps;
+		}
+
+		private void mergeToFinal(CTableMap map, CTableMap finalMap) {
+			Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = map.getIterator();
+			while(iter.hasNext()) {
+				LongLongDoubleHashMap.ADoubleEntry e = iter.next();
+				finalMap.aggregate(e.getKey1(), e.getKey2(), e.value);
+			}
+		}
+
+		@Override public Object call() throws Exception {
+			CTableMap mergedMap = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+			mergeToFinal(_map1, mergedMap);
+			mergeToFinal(_map2, mergedMap);
+			synchronized(_partialCTmaps) {
+				_partialCTmaps.add(mergedMap);
+				return null;
+			}
+		}
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index a0fcef6..683cd14 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5501,19 +5501,22 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		MatrixBlock that = checkType(thatVal);
 		MatrixBlock that2 = checkType(that2Val);
 		CTable ctable = CTable.getCTableFnObject();
-		
+		int k = OptimizerUtils.getTransformNumThreads();
 		//sparse-unsafe ctable execution
 		//(because input values of 0 are invalid and have to result in errors) 
-		if(resultBlock == null) 
-		{
-			for( int i=0; i<rlen; i++ )
-				for( int j=0; j<clen; j++ )
-				{
-					double v1 = this.quickGetValue(i, j);
-					double v2 = that.quickGetValue(i, j);
-					double w = that2.quickGetValue(i, j);
-					ctable.execute(v1, v2, w, false, resultMap);
-				}		
+		if(resultBlock == null) {
+			if (k > 1 && clen == 1)
+				//TODO: Find the optimum k during compilation
+				ctable.execute(this, that, that2, resultMap, k);
+			else {
+				for(int i = 0; i < rlen; i++)
+					for(int j = 0; j < clen; j++) {
+						double v1 = this.quickGetValue(i, j);
+						double v2 = that.quickGetValue(i, j);
+						double w = that2.quickGetValue(i, j);
+						ctable.execute(v1, v2, w, false, resultMap);
+					}
+			}
 		}
 		else 
 		{