You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/01/31 08:28:15 UTC
[systemds] branch main updated: [SYSTEMDS-3283] Multi-threaded ctable instruction
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 09f0be0 [SYSTEMDS-3283] Multi-threaded ctable instruction
09f0be0 is described below
commit 09f0be03c820d6851033bfc6469df7703cac0faa
Author: arnabp <ar...@tugraz.at>
AuthorDate: Mon Jan 31 09:26:46 2022 +0100
[SYSTEMDS-3283] Multi-threaded ctable instruction
This patch implements a multithreaded version of the
F = ctable(A, B, W) case. Other cases will be supported
in the future. Each thread constructs a separate
CTableMap from a block of rows. Later we cascade-merge
the partial maps.
This implementation shows an 8x speedup for
23M rows with 470K unique values.
Closes #1530.
---
.../sysds/runtime/functionobjects/CTable.java | 143 +++++++++++++++++++++
.../sysds/runtime/matrix/data/MatrixBlock.java | 25 ++--
2 files changed, 157 insertions(+), 11 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
index fc44ed1..291effa 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/CTable.java
@@ -24,8 +24,17 @@ import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.Pair;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+import org.apache.sysds.runtime.util.LongLongDoubleHashMap;
import org.apache.sysds.runtime.util.UtilFunctions;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
public class CTable extends ValueFunction
{
private static final long serialVersionUID = -5374880447194177236L;
@@ -139,4 +148,138 @@ public class CTable extends ValueFunction
throw new DMLRuntimeException("Erroneous input while computing the contingency table (value <= zero): "+v2);
return new Pair<>(new MatrixIndexes(row, col), w);
}
+
+	/**
+	 * Multithreaded CTable for F = ctable(A,B,W) (TODO: support other cases).
+	 * The rows are split into equal-sized blocks, one task per block; each task builds its
+	 * own CTableMap from column 0 of its rows without locking, then the partial maps are
+	 * cascade-merged pairwise and the final map is aggregated into resultMap.
+	 * @param resultMap map into which the merged (key1, key2, value) entries are aggregated
+	 * @param k degree of parallelism, used for the thread pool and the row partitioning
+	 */
+	public void execute(MatrixBlock in1, MatrixBlock in2, MatrixBlock w, CTableMap resultMap, int k) {
+		ExecutorService pool = CommonThreadPool.get(k);
+		ArrayList<CTableMap> partialMaps = new ArrayList<>();
+		try {
+			List<Callable<Object>> tasks = new ArrayList<>();
+			int[] blockSizes = UtilFunctions.getBlockSizes(in1.getNumRows(), k);
+			for(int startRow = 0, i = 0; i < blockSizes.length; startRow += blockSizes[i], i++)
+				tasks.add(getPartialCTableTask(in1, in2, w, startRow, blockSizes[i], partialMaps));
+			for(Future<Object> task : pool.invokeAll(tasks))
+				task.get(); // rethrows any task failure
+			// Cascade-merge the partial CTableMaps pairwise until at most one is left
+			ArrayList<CTableMap> newPartialMaps = new ArrayList<>();
+			while(partialMaps.size() > 1) {
+				newPartialMaps.clear();
+				tasks.clear();
+				int count;
+				// Each task merges 2 maps and appends the merged map to newPartialMaps
+				for(count = 0; count + 1 < partialMaps.size(); count += 2)
+					tasks.add(getMergePartialCTMapsTask(partialMaps.get(count),
+						partialMaps.get(count + 1), newPartialMaps));
+				for(Future<Object> task : pool.invokeAll(tasks))
+					task.get();
+				// Carry an unpaired (odd) map over to the next round
+				if(count < partialMaps.size())
+					newPartialMaps.add(partialMaps.get(count));
+				partialMaps.clear();
+				partialMaps.addAll(newPartialMaps);
+			}
+		}
+		catch(InterruptedException ex) {
+			// Restore the interrupt status before converting to a runtime exception
+			Thread.currentThread().interrupt();
+			throw new DMLRuntimeException(ex);
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			pool.shutdown(); // release the pool on both the success and the failure path
+		}
+		// Nothing to copy if no partial map was produced (e.g., getBlockSizes returned no blocks)
+		if(partialMaps.isEmpty())
+			return;
+		// Deep copy the final merged map into the result map
+		Iterator<LongLongDoubleHashMap.ADoubleEntry> iter = partialMaps.get(0).getIterator();
+		while(iter.hasNext()) {
+			LongLongDoubleHashMap.ADoubleEntry e = iter.next();
+			resultMap.aggregate(e.getKey1(), e.getKey2(), e.value);
+		}
+	}
+
+	/** Creates a task that builds a partial CTableMap for the row block starting at startInd and adds it to pmaps. */
+	public Callable<Object> getPartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+			int startInd, int blockSize, ArrayList<CTableMap> pmaps) {
+		return new PartialCTableTask(in1, in2, w, startInd, blockSize, pmaps);
+	}
+	/** Creates a task that merges map1 and map2 into a new CTableMap and adds it to pmaps. */
+	public Callable<Object> getMergePartialCTMapsTask(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+		return new MergePartialCTMaps(map1, map2, pmaps);
+	}
+	/** Builds a private CTableMap over one block of rows and publishes it to the shared list. */
+	private static class PartialCTableTask implements Callable<Object> {
+		private final MatrixBlock _in1, _in2, _w;
+		// first row (inclusive) and requested row count of this task's block
+		private final int _startInd, _blockSize;
+		// shared output list; guarded by synchronizing on the list itself
+		private final ArrayList<CTableMap> _partialCTmaps;
+
+		protected PartialCTableTask(MatrixBlock in1, MatrixBlock in2, MatrixBlock w,
+			int startRow, int blockSize, ArrayList<CTableMap> pmaps) {
+			_partialCTmaps = pmaps;
+			_startInd = startRow;
+			_blockSize = blockSize;
+			_in1 = in1;
+			_in2 = in2;
+			_w = w;
+		}
+
+		@Override
+		public Object call() throws Exception {
+			CTableMap localMap = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+			CTable fn = CTable.getCTableFnObject();
+			int stop = UtilFunctions.getEndIndex(_in1.getNumRows(), _startInd, _blockSize);
+			// only column 0 is read: this task is used for single-column inputs
+			for(int r = _startInd; r < stop; r++) {
+				double a = _in1.quickGetValue(r, 0);
+				double b = _in2.quickGetValue(r, 0);
+				fn.execute(a, b, _w.quickGetValue(r, 0), false, localMap);
+			}
+			synchronized(_partialCTmaps) {
+				_partialCTmaps.add(localMap);
+			}
+			return null;
+		}
+	}
+
+	/** Merges two partial CTableMaps into a fresh map and appends it to the shared list. */
+	private static class MergePartialCTMaps implements Callable<Object> {
+		private final CTableMap _map1, _map2;
+		private final ArrayList<CTableMap> _partialCTmaps;
+
+		protected MergePartialCTMaps(CTableMap map1, CTableMap map2, ArrayList<CTableMap> pmaps) {
+			_partialCTmaps = pmaps;
+			_map1 = map1;
+			_map2 = map2;
+		}
+
+		/** Aggregates every entry of src into dest. */
+		private void mergeToFinal(CTableMap src, CTableMap dest) {
+			for(Iterator<LongLongDoubleHashMap.ADoubleEntry> it = src.getIterator(); it.hasNext(); ) {
+				LongLongDoubleHashMap.ADoubleEntry entry = it.next();
+				dest.aggregate(entry.getKey1(), entry.getKey2(), entry.value);
+			}
+		}
+
+		@Override public Object call() throws Exception {
+			CTableMap merged = new CTableMap(LongLongDoubleHashMap.EntryType.INT);
+			mergeToFinal(_map1, merged);
+			mergeToFinal(_map2, merged);
+			synchronized(_partialCTmaps) { // the output list is shared by all merge tasks
+				_partialCTmaps.add(merged);
+			}
+			return null;
+		}
+	}
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index a0fcef6..683cd14 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -5501,19 +5501,22 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
MatrixBlock that = checkType(thatVal);
MatrixBlock that2 = checkType(that2Val);
CTable ctable = CTable.getCTableFnObject();
-
+ int k = OptimizerUtils.getTransformNumThreads();
//sparse-unsafe ctable execution
//(because input values of 0 are invalid and have to result in errors)
- if(resultBlock == null)
- {
- for( int i=0; i<rlen; i++ )
- for( int j=0; j<clen; j++ )
- {
- double v1 = this.quickGetValue(i, j);
- double v2 = that.quickGetValue(i, j);
- double w = that2.quickGetValue(i, j);
- ctable.execute(v1, v2, w, false, resultMap);
- }
+ if(resultBlock == null) {
+ if (k > 1 && clen == 1)
+ //TODO: Find the optimum k during compilation
+ ctable.execute(this, that, that2, resultMap, k);
+ else {
+ for(int i = 0; i < rlen; i++)
+ for(int j = 0; j < clen; j++) {
+ double v1 = this.quickGetValue(i, j);
+ double v2 = that.quickGetValue(i, j);
+ double w = that2.quickGetValue(i, j);
+ ctable.execute(v1, v2, w, false, resultMap);
+ }
+ }
}
else
{