You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/01/06 00:57:57 UTC

[systemds] branch main updated: [SYSTEMDS-3267] Explain for transformencode task-graph

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 636a683  [SYSTEMDS-3267] Explain for transformencode task-graph
636a683 is described below

commit 636a683a07b0a377289f0c83922abcd44c37a7f8
Author: arnabp <ar...@tugraz.at>
AuthorDate: Thu Jan 6 01:57:39 2022 +0100

    [SYSTEMDS-3267] Explain for transformencode task-graph
    
    This patch adds a method to print the task-graph of
    transformencode. Moreover, this commit integrates
    getMetadata tasks within the task-graph.
    
    Closes #1498
---
 .../runtime/transform/encode/ColumnEncoder.java    |  1 +
 .../transform/encode/ColumnEncoderDummycode.java   |  1 +
 .../transform/encode/ColumnEncoderPassThrough.java |  1 +
 .../transform/encode/MultiColumnEncoder.java       | 91 +++++++++++++++++-----
 .../sysds/runtime/util/DependencyThreadPool.java   | 67 ++++++++++++++--
 src/main/java/org/apache/sysds/utils/Explain.java  |  2 +-
 .../TransformFrameEncodeMultithreadedTest.java     |  9 +++
 7 files changed, 147 insertions(+), 25 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 9db3772..6c9ac6a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -132,6 +132,7 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
 
 	protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
 		boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
+		mcsr = false; //force CSR for transformencode
 		int index = _colID - 1;
 		// Apply loop tiling to exploit CPU caches
 		double[] codes = getCodeCol(in, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 63cf86c..5c1cd11 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -88,6 +88,7 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
 					" and not MatrixBlock");
 		}
 		boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
+		mcsr = false; //force CSR for transformencode
 		Set<Integer> sparseRowsWZeros = null;
 		int index = _colID - 1;
 		for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index f8b467d..36784ab 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -80,6 +80,7 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
 	protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
 		Set<Integer> sparseRowsWZeros = null;
 		boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
+		mcsr = false; //force CSR for transformencode
 		int index = _colID - 1;
 		// Apply loop tiling to exploit CPU caches
 		double[] codes = getCodeCol(in, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 76301ce..52bad53 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -43,6 +43,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -134,46 +135,67 @@ public class MultiColumnEncoder implements Encoder {
 		return out;
 	}
 
+	/* TASK DETAILS:
+	 * InitOutputMatrixTask:        Allocate output matrix
+	 * AllocMetaTask:               Allocate metadata frame
+	 * BuildTask:                   Build an encoder
+	 * ColumnCompositeUpdateDCTask: Update domain size of a DC encoder based on #distincts, #bins, K
+	 * ColumnMetaDataTask:          Fill up metadata of an encoder
+	 * ApplyTasksWrapperTask:       Wrapper task for an Apply
+	 * UpdateOutputColTask:         Sets starting offsets of the DC columns
+	 */
 	private List<DependencyTask<?>> getEncodeTasks(CacheBlock in, MatrixBlock out, DependencyThreadPool pool) {
 		List<DependencyTask<?>> tasks = new ArrayList<>();
 		List<DependencyTask<?>> applyTAgg = null;
 		Map<Integer[], Integer[]> depMap = new HashMap<>();
 		boolean hasDC = getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
 		boolean applyOffsetDep = false;
+		_meta = new FrameBlock(in.getNumColumns(), ValueType.STRING);
+		// Create the output and metadata allocation tasks
 		tasks.add(DependencyThreadPool.createDependencyTask(new InitOutputMatrixTask(this, in, out)));
+		tasks.add(DependencyThreadPool.createDependencyTask(new AllocMetaTask(this, _meta)));
+
 		for(ColumnEncoderComposite e : _columnEncoders) {
+			// Create the build tasks
 			List<DependencyTask<?>> buildTasks = e.getBuildTasks(in);
-
 			tasks.addAll(buildTasks);
 			if(buildTasks.size() > 0) {
-				// Apply Task dependency to build completion task
-				depMap.put(new Integer[] {tasks.size(), tasks.size() + 1},
-					new Integer[] {tasks.size() - 1, tasks.size()});
+				// Apply Task depends on build completion task
+				depMap.put(new Integer[] {tasks.size(), tasks.size() + 1},      //ApplyTask
+					new Integer[] {tasks.size() - 1, tasks.size()});            //BuildTask
+				// getMetaDataTask depends on build completion
+				depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, //MetaDataTask
+					new Integer[] {tasks.size() - 1, tasks.size()});           //BuildTask
+				// getMetaDataTask depends on AllocMeta task
+				depMap.put(new Integer[] {tasks.size() + 1, tasks.size() + 2}, //MetaDataTask
+					new Integer[] {1, 2});                                     //AllocMetaTask (2nd task)
+				// AllocMetaTask depends on the build completion tasks
+				depMap.put(new Integer[] {1, 2},                               //AllocMetaTask (2nd task)
+					new Integer[] {tasks.size() - 1, tasks.size()});           //BuildTask
 			}
 
-			// Apply Task dependency to InitOutputMatrixTask
-			depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, new Integer[] {0, 1});
+			// Apply Task depends on InitOutputMatrixTask (output allocation)
+			depMap.put(new Integer[] {tasks.size(), tasks.size() + 1},         //ApplyTask
+					new Integer[] {0, 1});                                     //Allocation task (1st task)
 			ApplyTasksWrapperTask applyTaskWrapper = new ApplyTasksWrapperTask(e, in, out, pool);
 
 			if(e.hasEncoder(ColumnEncoderDummycode.class)) {
-				// InitMatrix dependency to build of recode if a DC is present
-				// Since they are the only ones that change the domain size which would influence the Matrix creation
-				depMap.put(new Integer[] {0, 1}, // InitMatrix Task first in list
-					new Integer[] {tasks.size() - 1, tasks.size()});
-				// output col update task dependent on Build completion only for Recode and binning since they can
-				// change dummycode domain size
-				// colUpdateTask can start when all domain sizes, because it can now calculate the offsets for
-				// each column
-				depMap.put(new Integer[] {-2, -1}, new Integer[] {tasks.size() - 1, tasks.size()});
+				// Allocation depends on build if DC is in the list.
+				// Note, DC is the only encoder that changes dimensionality
+				depMap.put(new Integer[] {0, 1},                               //Allocation task (1st task)
+					new Integer[] {tasks.size() - 1, tasks.size()});           //BuildTask
+				// UpdateOutputColTask, that sets the starting offsets of the DC columns,
+				// depends on the Build completion tasks
+				depMap.put(new Integer[] {-2, -1},                             //UpdateOutputColTask (last task) 
+						new Integer[] {tasks.size() - 1, tasks.size()});       //BuildTask
 				buildTasks.forEach(t -> t.setPriority(5));
 				applyOffsetDep = true;
 			}
 
 			if(hasDC && applyOffsetDep) {
-				// Apply Task dependency to output col update task (is last in list)
-				// All ApplyTasks need to wait for this task, so they all have the correct offsets.
-				// But only for the columns that come after the first DC coder since they don't have an offset
-				depMap.put(new Integer[] {tasks.size(), tasks.size() + 1}, new Integer[] {-2, -1});
+				// Apply tasks depend on UpdateOutputColTask
+				depMap.put(new Integer[] {tasks.size(), tasks.size() + 1},     //ApplyTask 
+						new Integer[] {-2, -1});                               //UpdateOutputColTask (last task)
 
 				applyTAgg = applyTAgg == null ? new ArrayList<>() : applyTAgg;
 				applyTAgg.add(applyTaskWrapper);
@@ -181,9 +203,13 @@ public class MultiColumnEncoder implements Encoder {
 			else {
 				applyTaskWrapper.setOffset(0);
 			}
+			// Create the ApplyTask (wrapper)
 			tasks.add(applyTaskWrapper);
+			// Create the getMetadata task
+			tasks.add(DependencyThreadPool.createDependencyTask(new ColumnMetaDataTask<ColumnEncoder>(e, _meta)));
 		}
 		if(hasDC)
+			// Create the last task, UpdateOutputColTask
 			tasks.add(DependencyThreadPool.createDependencyTask(new UpdateOutputColTask(this, applyTAgg)));
 
 		List<List<? extends Callable<?>>> deps = new ArrayList<>(Collections.nCopies(tasks.size(), null));
@@ -330,6 +356,7 @@ public class MultiColumnEncoder implements Encoder {
 					&& MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR)
 				throw new RuntimeException("Transformapply is only supported for MCSR and CSR output matrix");
 			boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
+			mcsr = false; //force CSR for transformencode
 			if (mcsr) {
 				output.allocateBlock();
 				SparseBlock block = output.getSparseBlock();
@@ -933,6 +960,27 @@ public class MultiColumnEncoder implements Encoder {
 			return null;
 		}
 	}
+
+	/* Task that calls allocateMetaData on the given encoder for the given frame. */
+	private static class AllocMetaTask implements Callable<Object> {
+		private final FrameBlock _meta;            // frame handed to allocateMetaData
+		private final MultiColumnEncoder _encoder; // encoder the call is made on
+
+		private AllocMetaTask(MultiColumnEncoder owner, FrameBlock metaFrame) {
+			this._encoder = owner;
+			this._meta = metaFrame;
+		}
+		@Override
+		public Object call() throws Exception {
+			this._encoder.allocateMetaData(this._meta);
+			return null; // the task has no result of its own
+		}
+
+		@Override
+		public String toString() {
+			return this.getClass().getSimpleName(); // simple class name serves as the label
+		}
+	}
 	
 	private static class ColumnMetaDataTask<T extends ColumnEncoder> implements Callable<Object> {
 		private final T _colEncoder;
@@ -948,6 +996,11 @@ public class MultiColumnEncoder implements Encoder {
 			_colEncoder.getMetaData(_out);
 			return null;
 		}
+
+		@Override
+		public String toString() { // label: simple class name followed by the encoder's column id
+			return new StringBuilder(getClass().getSimpleName()).append("<ColId: ").append(_colEncoder._colID).append(">").toString();
+		}
 	}
 
 }
diff --git a/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
index 50675d6..90d1dfc 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DependencyThreadPool.java
@@ -22,8 +22,10 @@ package org.apache.sysds.runtime.util;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.utils.Explain;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -90,7 +92,10 @@ public class DependencyThreadPool {
 	public List<Object> submitAllAndWait(List<DependencyTask<?>> dtasks)
 		throws ExecutionException, InterruptedException {
 		List<Object> res = new ArrayList<>();
-		// printDependencyGraph(dtasks);
+		if(DependencyTask.ENABLE_DEBUG_DATA) {
+			if (dtasks != null && dtasks.size() > 0)
+				explainTaskGraph(dtasks);
+		}
 		List<Future<Future<?>>> futures = submitAll(dtasks);
 		int i = 0;
 		for(Future<Future<?>> ff : futures) {
@@ -112,10 +117,12 @@ public class DependencyThreadPool {
 	}
 
 	/*
-	 * Creates the Dependency list from a map and the tasks. The map specifies which tasks should have a Dependency on
-	 * which other task. e.g.
-	 * ([0, 3], [4, 6])   means the first 3 tasks in the tasks list are dependent on tasks at index 4 and 5
-	 * ([-2, -1], [0, 5]) means the last task has a Dependency on the first 5 tasks.
+	 * Creates the Dependency list from a map and the tasks. The map specifies which tasks 
+	 * should have a Dependency on which other task. e.g.
+	 * ([0, 3], [4, 6])   means the 1st 3 tasks in the list are dependent on tasks at index 4 and 5
+	 * ([-2, -1], [0, 5]) means the last task depends on the first 5 tasks.
+	 * ([dependent start index, dependent end index (excluding)], 
+	 *  [parent start index, parent end index (excluding)])
 	 */
 	public static List<List<? extends Callable<?>>> createDependencyList(List<? extends Callable<?>> tasks,
 		Map<Integer[], Integer[]> depMap, List<List<? extends Callable<?>>> dep) {
@@ -175,4 +182,54 @@ public class DependencyThreadPool {
 		}
 		return ret;
 	}
+
+	/*
+	 * Prints the task-graph level-wise: a task without dependencies is at
+	 * level 0, any other task is one level below its deepest parent. The
+	 * output doesn't specify which task of level l depends on which task of
+	 * level (l-1). Tasks whose parents never get a level (cyclic dependency,
+	 * or a parent that is not in the list) are listed as unresolved.
+	 */
+	public static void explainTaskGraph(List<DependencyTask<?>> tasks) {
+		if (tasks == null || tasks.isEmpty())
+			return;
+		Map<DependencyTask<?>, Integer> levelMap = new HashMap<>();
+		int depth = 0;
+		boolean progress = true;
+		while (progress && levelMap.size() < tasks.size()) { // a pass without progress ends the loop
+			progress = false;
+			for (DependencyTask<?> dt : tasks) {
+				if (levelMap.containsKey(dt))
+					continue;
+				List<DependencyTask<?>> parents = dt._dependencyTasks;
+				int level = 0;
+				boolean missing = false;
+				for (int p=0; parents != null && p<parents.size() && !missing; p++) {
+					Integer parentLevel = levelMap.get(parents.get(p));
+					missing = (parentLevel == null);
+					level = missing ? level : Math.max(level, parentLevel+1);
+				}
+				if (missing)
+					continue; // retry in a later pass
+				levelMap.put(dt, level);
+				depth = Math.max(depth, level+1);
+				progress = true;
+			}
+		}
+		StringBuilder sbs[] = new StringBuilder[depth];
+		Arrays.setAll(sbs, i -> new StringBuilder());
+		StringBuilder unresolved = new StringBuilder();
+		for (DependencyTask<?> dt : tasks) { // list order (not HashMap order) keeps the output stable
+			Integer level = levelMap.put(dt, -1); // -1 marks "already printed"
+			if (level == null)
+				unresolved.append(dt).append("\n");
+			else if (level >= 0)
+				sbs[level].append(Explain.createOffset(level)).append(dt).append("\n");
+		}
+		System.out.println("EXPLAIN (TASK-GRAPH):");
+		for (StringBuilder sb : sbs)
+			System.out.println(sb.toString());
+		if (unresolved.length() > 0)
+			System.out.println("Unresolved tasks:\n" + unresolved);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index ae6a523..ba6fb71 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -830,7 +830,7 @@ public class Explain
 		return OptimizerUtils.toMB(mem) + (units?"MB":"");
 	}
 
-	private static String createOffset( int level )
+	public static String createOffset( int level )
 	{
 		StringBuilder sb = new StringBuilder();
 		for( int i=0; i<level; i++ )
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
index fbf7111..560f934 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeMultithreadedTest.java
@@ -24,6 +24,7 @@ import java.nio.file.Paths;
 
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.FrameReaderFactory;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -211,11 +212,19 @@ public class TransformFrameEncodeMultithreadedTest extends AutomatedTestBase {
 			MultiColumnEncoder.MULTI_THREADED_STAGES = staged;
 
 			MatrixBlock outputS = encoder.encode(input, 1);
+			FrameBlock metaS = encoder.getMetaData(new FrameBlock(input.getNumColumns(), ValueType.STRING), 1);
 			MatrixBlock outputM = encoder.encode(input, 12);
+			FrameBlock metaM = encoder.getMetaData(new FrameBlock(input.getNumColumns(), ValueType.STRING), 12);
 
+			// Match encoded matrices
 			double[][] R1 = DataConverter.convertToDoubleMatrix(outputS);
 			double[][] R2 = DataConverter.convertToDoubleMatrix(outputM);
 			TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0);
+			// Match the metadata frames
+			String[][] M1 = DataConverter.convertToStringFrame(metaS);
+			String[][] M2 = DataConverter.convertToStringFrame(metaM);
+			TestUtils.compareFrames(M1, M2, M1.length, M1[0].length);
+
 			Assert.assertEquals(outputS.getNonZeros(), outputM.getNonZeros());
 			Assert.assertTrue(outputM.getNonZeros() > 0);