You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/15 20:11:30 UTC

[systemml] branch master updated: [SYSTEMDS-233] Fix multi-level lineage caching (parfor, determinism)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 9bd68ff  [SYSTEMDS-233] Fix multi-level lineage caching (parfor, determinism)
9bd68ff is described below

commit 9bd68ffc5d211583a2ebcfe5be514abf4cc29b69
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Wed Apr 15 21:46:16 2020 +0200

    [SYSTEMDS-233] Fix multi-level lineage caching (parfor, determinism)
    
    This patch fixes some issues with multi-level lineage caching in parfor,
    specifically (1) function reuse was prevented by the differently named
    parfor worker functions, and (2) the check for deterministic function
    results incorrectly probed too far and thus missed reuse opportunities.
    
    However, down the road we should add an IPA pass which determines once
    for all functions if they are deterministic and passes this information
    down to the runtime, in order to avoid scenarios where threads are
    already blocking on placeholders that are later removed due to
    non-deterministic functions.
---
 .../apache/sysds/hops/recompile/Recompiler.java    | 10 +++++-----
 src/main/java/org/apache/sysds/lops/Lop.java       |  2 +-
 .../sysds/runtime/controlprogram/ProgramBlock.java | 17 ++++++++++++++--
 .../instructions/cp/FunctionCallCPInstruction.java | 23 +++++++++++++++++-----
 .../apache/sysds/runtime/lineage/LineageCache.java | 16 +++++++++------
 .../runtime/lineage/LineageCacheStatistics.java    | 10 +++++++++-
 .../sysds/runtime/lineage/LineageItemUtils.java    | 10 +++-------
 .../java/org/apache/sysds/utils/Statistics.java    |  2 +-
 .../functions/lineage/FunctionFullReuseTest.java   |  7 +++++++
 .../functions/lineage/FunctionFullReuse6.dml       |  4 ++--
 10 files changed, 71 insertions(+), 30 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 2b11c73..d058c6a 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -155,7 +155,7 @@ public class Recompiler
 		}
 		
 		// replace thread ids in new instructions
-		if( tid != 0 ) //only in parfor context
+		if( ProgramBlock.isThreadID(tid) ) //only in parfor context
 			newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false);
 		
 		// remove writes if called through mlcontext or jmlc 
@@ -187,7 +187,7 @@ public class Recompiler
 		}
 		
 		// replace thread ids in new instructions
-		if( tid != 0 ) //only in parfor context
+		if( ProgramBlock.isThreadID(tid) ) //only in parfor context
 			newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false);
 		
 		// explain recompiled instructions
@@ -209,7 +209,7 @@ public class Recompiler
 		}
 		
 		// replace thread ids in new instructions
-		if( tid != 0 ) //only in parfor context
+		if( ProgramBlock.isThreadID(tid) ) //only in parfor context
 			newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false);
 		
 		// explain recompiled instructions
@@ -231,7 +231,7 @@ public class Recompiler
 		}
 		
 		// replace thread ids in new instructions
-		if( tid != 0 ) //only in parfor context
+		if( ProgramBlock.isThreadID(tid) ) //only in parfor context
 			newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false);
 		
 		// explain recompiled hops / instructions
@@ -253,7 +253,7 @@ public class Recompiler
 		}
 
 		// replace thread ids in new instructions
-		if( tid != 0 ) //only in parfor context
+		if( ProgramBlock.isThreadID(tid) ) //only in parfor context
 			newInst = ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null, null, false, false);
 		
 		// explain recompiled hops / instructions
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index fa25000..8bb7e1a 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -82,7 +82,7 @@ public abstract class Lop
 	public static final String PROCESS_PREFIX = "_p";
 	public static final String CP_ROOT_THREAD_ID = "_t0";
 	public static final String CP_CHILD_THREAD = "_t";
-	public static final double SAMPLE_FRACTION = 0.01;										// for row sampling in distributed frame meta operations
+	public static final double SAMPLE_FRACTION = 0.01; // for row sampling in distributed frame meta operations
 	
 	//special delimiters w/ extended ASCII characters to avoid collisions 
 	public static final String INSTRUCTION_DELIMITOR = "\u2021";
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 4f5ef85..5cde84e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -82,14 +82,27 @@ public abstract class ProgramBlock implements ParseInfo
 		return _sb;
 	}
 
-	public void setStatementBlock( StatementBlock sb ){
+	public void setStatementBlock(StatementBlock sb){
 		_sb = sb;
 	}
 
-	public void setThreadID( long id ){
+	public void setThreadID(long id){
 		_tid = id;
 	}
 	
+	public boolean hasThreadID() {
+		return _tid != 0;
+	}
+	
+	public static boolean isThreadID (long tid) {
+		return tid != 0;
+	}
+	
+	public long getThreadID() {
+		return _tid;
+	}
+	
+	
 	/**
 	 * Get the list of child program blocks if nested;
 	 * otherwise this method returns null.
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 9c1eac0..5d7feee 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -225,8 +225,10 @@ public class FunctionCallCPInstruction extends CPInstruction {
 		}
 
 		//update lineage cache with the functions outputs
-		if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() )
-			LineageCache.putValue(fpb.getOutputParams(), liInputs, _functionName, ec);
+		if( DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() ) {
+			LineageCache.putValue(fpb.getOutputParams(), 
+				liInputs, getCacheFunctionName(_functionName, fpb), ec);
+		}
 	}
 
 	@Override
@@ -249,7 +251,7 @@ public class FunctionCallCPInstruction extends CPInstruction {
 		//split current instruction
 		String[] parts = instString.split(Lop.OPERAND_DELIMITOR);
 		if( parts[3].equals(pattern) )
-			parts[3] = replace;	
+			parts[3] = replace;
 		
 		//construct and set modified instruction
 		StringBuilder sb = new StringBuilder();
@@ -262,14 +264,25 @@ public class FunctionCallCPInstruction extends CPInstruction {
 	}
 	
 	private boolean reuseFunctionOutputs(LineageItem[] liInputs, FunctionProgramBlock fpb, ExecutionContext ec) {
+		//prepare lineage cache probing
+		String funcName = getCacheFunctionName(_functionName, fpb);
 		int numOutputs = Math.min(_boundOutputNames.size(), fpb.getOutputParams().size());
-		boolean reuse = LineageCache.reuse(_boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, _functionName, ec);
+		
+		//reuse of function outputs
+		boolean reuse = LineageCache.reuse(
+			_boundOutputNames, fpb.getOutputParams(), numOutputs, liInputs, funcName, ec);
 
+		//statistics maintenance
 		if (reuse && DMLScript.STATISTICS) {
 			//decrement the call count for this function
-			Statistics.maintainCPFuncCallStats(this.getExtendedOpcode());
+			Statistics.maintainCPFuncCallStats(getExtendedOpcode());
 			LineageCacheStatistics.incrementFuncHits();
 		}
 		return reuse;
 	}
+	
+	private static String getCacheFunctionName(String fname, FunctionProgramBlock fpb) {
+		return !fpb.hasThreadID() ? fname :
+			fname.substring(0, fname.lastIndexOf(Lop.CP_CHILD_THREAD+fpb.getThreadID()));
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 2741b70..0d93699 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -143,17 +143,18 @@ public class LineageCache
 			LineageItem li = new LineageItem(outNames.get(i), opcode, liInputs);
 			Entry e = null;
 			synchronized( _cache ) {
-				if (LineageCache.probe(li)) 
+				if (LineageCache.probe(li)) {
 					e = LineageCache.getIntern(li);
-				else
+				}
+				else {
 					//create a placeholder if no reuse to avoid redundancy
 					//(e.g., concurrent threads that try to start the computation)
 					putIntern(li, outParams.get(i).getDataType(), null, null, 0);
-					//FIXME: parfor - every thread gets different function names
+				}
 			}
 			//TODO: handling of recursive calls
 			
-			if (e != null && !e.isNullVal()) {
+			if ( e != null ) {
 				String boundVarName = outNames.get(i);
 				Data boundValue = null;
 				//convert to matrix object
@@ -164,8 +165,9 @@ public class LineageCache
 					((MatrixObject)boundValue).acquireModify(e.getMBValue());
 					((MatrixObject)boundValue).release();
 				}
-				else
+				else {
 					boundValue = e.getSOValue();
+				}
 
 				funcOutputs.put(boundVarName, boundValue);
 				LineageItem orig = e._origItem;
@@ -250,7 +252,7 @@ public class LineageCache
 	
 	public static void putValue(List<DataIdentifier> outputs, LineageItem[] liInputs, String name, ExecutionContext ec)
 	{
-		if( !LineageCacheConfig.isMultiLevelReuse())
+		if( !LineageCacheConfig.isMultiLevelReuse() )
 			return;
 
 		HashMap<LineageItem, LineageItem> FuncLIMap = new HashMap<>();
@@ -264,6 +266,8 @@ public class LineageCache
 				boundLI.resetVisitStatus();
 			if (boundLI == null 
 				|| !LineageCache.probe(li)
+				//TODO remove this brittle constraint (if the placeholder is removed
+				//it might crash threads that are already waiting for its results)
 				|| LineageItemUtils.containsRandDataGen(new HashSet<>(Arrays.asList(liInputs)), boundLI)) {
 				AllOutputsCacheable = false;
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 98ad75e..9704797 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -122,6 +122,14 @@ public class LineageCacheStatistics {
 		// Total time spent compiling lineage rewrites.
 		_ctimeRewrite.add(delta);
 	}
+	
+	public static long getMultiLevelFnHits() {
+		return _numHitsFunc.longValue();
+	}
+	
+	public static long getMultiLevelSBHits() {
+		return _numHitsSB.longValue();
+	}
 
 	public static void incrementPRwExecTime(long delta) {
 		// Total time spent executing lineage rewrites.
@@ -138,7 +146,7 @@ public class LineageCacheStatistics {
 		return sb.toString();
 	}
 
-	public static String displayMultiLvlHits() {
+	public static String displayMultiLevelHits() {
 		StringBuilder sb = new StringBuilder();
 		sb.append(_numHitsInst.longValue());
 		sb.append("/");
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index eade225..aeeacd3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -612,16 +612,12 @@ public class LineageItemUtils {
 	public static boolean containsRandDataGen(HashSet<LineageItem> entries, LineageItem root) {
 		if (entries.contains(root) || root.isVisited())
 			return false;
-
-		boolean isRand = false;
-		if (isNonDeterministic(root))
-			isRand |= true;
-		if (!root.isLeaf()) 
+		boolean isRand = isNonDeterministic(root);
+		if (!root.isLeaf() && !isRand) 
 			for (LineageItem input : root.getInputs())
-				isRand = isRand ? true : containsRandDataGen(entries, input);
+				isRand |= containsRandDataGen(entries, input);
 		root.setVisited();
 		return isRand;
-		//TODO: unmark for caching in compile time
 	}
 	
 	private static boolean isNonDeterministic(LineageItem li) {
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 9445caf..4c3cbef 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -945,7 +945,7 @@ public class Statistics
 			}
 			if (DMLScript.LINEAGE && !ReuseCacheType.isNone()) {
 				sb.append("LinCache hits (Mem/FS/Del): \t" + LineageCacheStatistics.displayHits() + ".\n");
-				sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLvlHits() + ".\n");
+				sb.append("LinCache MultiLevel (Ins/SB/Fn):" + LineageCacheStatistics.displayMultiLevelHits() + ".\n");
 				sb.append("LinCache writes (Mem/FS): \t" + LineageCacheStatistics.displayWtrites() + ".\n");
 				sb.append("LinCache FStimes (Rd/Wr): \t" + LineageCacheStatistics.displayTime() + " sec.\n");
 				sb.append("LinCache costing time:  \t" + LineageCacheStatistics.displayCostingTime() + " sec.\n");
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
index 22740f3..8fc7f78 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/FunctionFullReuseTest.java
@@ -19,13 +19,16 @@
 
 package org.apache.sysds.test.functions.lineage;
 
+import org.junit.Assert;
 import org.junit.Test;
+
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.lops.LopProperties.ExecType;
 import org.apache.sysds.runtime.lineage.Lineage;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -128,6 +131,10 @@ public class FunctionFullReuseTest extends AutomatedTestBase
 			Lineage.setLinReuseNone();
 			
 			TestUtils.compareMatrices(X_orig, X_reused, 1e-6, "Origin", "Reused");
+			if( testname.endsWith("6") ) { // parfor fn reuse
+				Assert.assertEquals(9L, LineageCacheStatistics.getMultiLevelFnHits() 
+					+ LineageCacheStatistics.getMultiLevelSBHits());
+			}
 		}
 		finally {
 			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
diff --git a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
index 2b025d5..02af351 100644
--- a/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
+++ b/src/test/scripts/functions/lineage/FunctionFullReuse6.dml
@@ -20,9 +20,9 @@
 #-------------------------------------------------------------
 
 foo = function(Matrix[Double] X) return (Matrix[Double] R) {
-  y = X + X - 2 * sqrt(X) + X * X;
+  Y = X + X - 2 * sqrt(X) + X * X;
   while(FALSE){}
-  R = rowSums(y)*colSums(y);
+  R = rowSums(Y)%*%colSums(Y);
 }
 
 X = rand(rows=100, cols=10, seed=7);