You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/06/11 12:28:09 UTC

[systemml] branch master updated: [SYSTEMDS-412] Fix robustness lineage DAGs, parfor integration

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new e8c0a28  [SYSTEMDS-412] Fix robustness lineage DAGs, parfor integration
e8c0a28 is described below

commit e8c0a28c95b9a22f2a023715a3717c36528bd3ab
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jun 11 14:08:13 2020 +0200

    [SYSTEMDS-412] Fix robustness lineage DAGs, parfor integration
    
    This patch makes further robustness improvements to the handling of
    large lineage DAGs via non-recursive primitives. In this context, the
    lineage explain functionality needed special treatment to preserve its
    previous output, i.e., DFS order in which each item is appended only
    after all of its inputs (post-order).
    
    Furthermore, this fixes a number of issues in the parfor integration:
    (1) invalid cached hashes after sub-DAG replacement, and (2) cycles
    introduced during the parfor lineage merge. It also includes (3) steplm
    script improvements (the disabled parfor dependency analysis, check = 0,
    had been hiding the issue that introduced the cycles), and (4) debugging
    functionality to reliably detect cycles in lineage DAGs.
---
 scripts/builtin/steplm.dml                         | 20 ++++----
 .../instructions/cp/DataGenCPInstruction.java      |  2 +-
 .../apache/sysds/runtime/lineage/LineageItem.java  |  5 ++
 .../sysds/runtime/lineage/LineageItemUtils.java    | 55 ++++++++++++++++++++--
 src/main/java/org/apache/sysds/utils/Explain.java  | 49 +++++++++++++++----
 .../test/functions/lineage/LineageReuseAlg.java    | 37 ++++++++++-----
 .../functions/lineage/LineageTraceParforSteplm.dml |  4 +-
 7 files changed, 134 insertions(+), 38 deletions(-)

diff --git a/scripts/builtin/steplm.dml b/scripts/builtin/steplm.dml
index 01f35ba..800c2ca 100644
--- a/scripts/builtin/steplm.dml
+++ b/scripts/builtin/steplm.dml
@@ -98,7 +98,7 @@ m_steplm = function(Matrix[Double] X, Matrix[Double] y, Integer icpt = 0,
 
   # First pass to examine single features
   AICs = matrix(0, 1, m_orig);
-  parfor (i in 1:m_orig, check = 0) {
+  parfor (i in 1:m_orig) {
     [AIC_1, beta_out_i] = linear_regression(X_orig[, i], y, icpt, reg, tol, maxi, verbose);
     AICs[1, i] = AIC_1;
     beta_out_all[1:nrow(beta_out_i), i] = beta_out_i;
@@ -129,25 +129,25 @@ m_steplm = function(Matrix[Double] X, Matrix[Double] y, Integer icpt = 0,
     while (continue) {
       # Subsequent passes over the features
       beta_out_all_2 = matrix(0, boa_ncol, m_orig * 1);
-      AICs = matrix(0, 1, m_orig); # full overwrite
-      parfor (i in 1:m_orig, check = 0) {
+      AICs_2 = matrix(0, 1, m_orig); # full overwrite
+      parfor (i in 1:m_orig) {
         if (as.scalar(columns_fixed[1, i]) == 0) {
           # Construct the feature matrix
-          X = cbind(X_global, X_orig[, i]);
-          [AIC_2, beta_out_i] = linear_regression(X, y, icpt, reg, tol, maxi, verbose);
-          AICs[1, i] = AIC_2;
-          beta_out_all_2[1:nrow(beta_out_i), i] = beta_out_i;
+          Xi = cbind(X_global, X_orig[, i]);
+          [AIC_2, beta_out_i2] = linear_regression(Xi, y, icpt, reg, tol, maxi, verbose);
+          AICs_2[1, i] = AIC_2;
+          beta_out_all_2[1:nrow(beta_out_i2), i] = beta_out_i2;
         }
         else {
-          AICs[1,i] = Inf;
+          AICs_2[1,i] = Inf;
         }
       }
 
       # Determine the best AIC
       AIC_best_orig = AIC_best;
-      AIC_best = min(min(AICs), AIC_best_orig);
+      AIC_best = min(min(AICs_2), AIC_best_orig);
       AIC_check = checkAIC(AIC_best, AIC_best_orig, thr);
-      column_best = ifelse(AIC_check, as.scalar(rowIndexMin(AICs)), column_best);
+      column_best = ifelse(AIC_check, as.scalar(rowIndexMin(AICs_2)), column_best);
 
       # have the best beta store in the matrix
       beta_best = beta_out_all_2[, column_best];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
index baacca6..8d688b8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/DataGenCPInstruction.java
@@ -402,7 +402,7 @@ public class DataGenCPInstruction extends UnaryCPInstruction {
 						tmpInstStr, position, String.valueOf(runtimeSeed)) : tmpInstStr;
 				}
 				//replace output variable name with a placeholder
-				//tmpInstStr = InstructionUtils.replaceOperandName(tmpInstStr);
+				tmpInstStr = InstructionUtils.replaceOperandName(tmpInstStr);
 				tmpInstStr = replaceNonLiteral(tmpInstStr, rows, 2, ec);
 				tmpInstStr = replaceNonLiteral(tmpInstStr, cols, 3, ec);
 				break;
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
index 38a4cb9..b936948 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItem.java
@@ -83,6 +83,11 @@ public class LineageItem {
 		return _inputs;
 	}
 	
+	public void setInput(int i, LineageItem item) {
+		_inputs[i] = item;
+		_hash = 0; //reset hash
+	}
+	
 	public String getData() {
 		return _data;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index c49ba00..467bbc9 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -86,6 +86,8 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
 import java.util.stream.Collectors;
 
 public class LineageItemUtils {
@@ -633,12 +635,39 @@ public class LineageItemUtils {
 	}
 	
 	public static LineageItem replace(LineageItem root, LineageItem liOld, LineageItem liNew) {
+		if( liNew == null )
+			throw new DMLRuntimeException("Invalid null lineage item for "+liOld.getId());
 		root.resetVisitStatusNR();
-		rReplace(root, liOld, liNew);
+		rReplaceNR(root, liOld, liNew);
 		root.resetVisitStatusNR();
 		return root;
 	}
 	
+	/**
+	 * Non-recursive equivalent of {@link #rReplace(LineageItem, LineageItem, LineageItem)} 
+	 * for robustness with regard to stack overflow errors.
+	 */
+	public static void rReplaceNR(LineageItem current, LineageItem liOld, LineageItem liNew) {
+		Stack<LineageItem> q = new Stack<>();
+		q.push(current);
+		while( !q.empty() ) {
+			LineageItem tmp = q.pop();
+			if( tmp.isVisited() || tmp.getInputs() == null )
+				continue;
+			//process children until old item found, then replace
+			for(int i=0; i<tmp.getInputs().length; i++) {
+				LineageItem ctmp = tmp.getInputs()[i];
+				if (liOld.getId() == ctmp.getId() && liOld.equals(ctmp))
+					tmp.setInput(i, liNew);
+				else
+					q.push(ctmp);
+			}
+			tmp.setVisited(true);
+		}
+	}
+	
+	@Deprecated
+	@SuppressWarnings("unused")
 	private static void rReplace(LineageItem current, LineageItem liOld, LineageItem liNew) {
 		if( current.isVisited() || current.getInputs() == null )
 			return;
@@ -648,7 +677,7 @@ public class LineageItemUtils {
 		for(int i=0; i<current.getInputs().length; i++) {
 			LineageItem tmp = current.getInputs()[i];
 			if (liOld.equals(tmp))
-				current.getInputs()[i] = liNew;
+				current.setInput(i, liNew);
 			else
 				rReplace(tmp, liOld, liNew);
 		}
@@ -671,7 +700,7 @@ public class LineageItemUtils {
 			if (li.isLeaf() && li.getType() != LineageItemType.Literal
 				&& li.getData().startsWith(LPLACEHOLDER))
 				//order-preserving replacement. IN#<xxx> represents relative position xxx
-				root.getInputs()[i] = newleaves[Integer.parseInt(li.getData().substring(3))];
+				root.setInput(i, newleaves[Integer.parseInt(li.getData().substring(3))]);
 			else
 				rReplaceDagLeaves(li, newleaves);
 		}
@@ -692,6 +721,25 @@ public class LineageItemUtils {
 		root.setVisited();
 	}
 	
+	public static void checkCycles(LineageItem current) {
+		current.resetVisitStatusNR();
+		rCheckCycles(current, new HashSet<Long>(), true);
+		current.resetVisitStatusNR();
+	}
+	
+	public static void rCheckCycles(LineageItem current, Set<Long> probe, boolean useObjIdent) {
+		if( current.isVisited() )
+			return;
+		long id = useObjIdent ? System.identityHashCode(current) : current.getId();
+		if( probe.contains(id) )
+			throw new DMLRuntimeException("Cycle detected for "+current.toString());
+		probe.add(id);
+		if( current.getInputs() != null )
+			for(LineageItem li : current.getInputs())
+				rCheckCycles(li, probe, useObjIdent);
+		current.setVisited();
+	}
+	
 	private static Hop[] createNaryInputs(LineageItem item, Map<Long, Hop> operands) {
 		int len = item.getInputs().length;
 		Hop[] ret = new Hop[len];
@@ -766,5 +814,4 @@ public class LineageItemUtils {
 		return(CPOpInputs != null ? LineageItemUtils.getLineage(ec, 
 			CPOpInputs.toArray(new CPOperand[CPOpInputs.size()])) : null);
 	}
-	
 }
diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java
index d9e541e..9853e28 100644
--- a/src/main/java/org/apache/sysds/utils/Explain.java
+++ b/src/main/java/org/apache/sysds/utils/Explain.java
@@ -26,7 +26,9 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.List;
 import java.util.Map.Entry;
+import java.util.Stack;
 
+import org.apache.commons.lang3.mutable.MutableInt;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
@@ -339,7 +341,7 @@ public class Explain
 		StringBuilder sb = new StringBuilder();
 		LineageItem.resetVisitStatusNR(lis);
 		for( LineageItem li : lis )
-			sb.append(explainLineageItem(li, level));
+			sb.append(explainLineageItemNR(li, level));
 		LineageItem.resetVisitStatusNR(lis);
 		return sb.toString();
 	}
@@ -354,7 +356,7 @@ public class Explain
 
 	private static String explain( LineageItem li, int level ) {
 		li.resetVisitStatusNR();
-		String ret = explainLineageItem(li, level);
+		String ret = explainLineageItemNR(li, level);
 		li.resetVisitStatusNR();
 		return ret;
 	}
@@ -595,13 +597,42 @@ public class Explain
 		return sb.toString();
 	}
 
-	/**
-	 * Do a post-order traverse through the Lineage Item DAG and explain each Hop
-	 *
-	 * @param li lineage item
-	 * @param level offset
-	 * @return string explanation of Lineage Item DAG
-	 */
+	private static String explainLineageItemNR(LineageItem item, int level) {
+		//NOTE: in contrast to similar non-recursive functions like resetVisitStatusNR,
+		// we maintain a more complex stack to ensure DFS ordering where current nodes
+		// are added after the subtree underneath is processed (backwards compatibility)
+		Stack<LineageItem> stackItem = new Stack<>();
+		Stack<MutableInt> stackPos = new Stack<>();
+		stackItem.push(item); stackPos.push(new MutableInt(0));
+		StringBuilder sb = new StringBuilder();
+		while( !stackItem.empty() ) {
+			LineageItem tmpItem = stackItem.peek();
+			MutableInt tmpPos = stackPos.peek();
+			//check ascent condition - no item processing
+			if( tmpItem.isVisited() ) {
+				stackItem.pop(); stackPos.pop();
+			}
+			//check ascent condition - append item
+			else if( tmpItem.getInputs() == null 
+				|| tmpItem.getInputs().length <= tmpPos.intValue() ) {
+				sb.append(createOffset(level));
+				sb.append(tmpItem.toString());
+				sb.append('\n');
+				stackItem.pop(); stackPos.pop();
+				tmpItem.setVisited();
+			}
+			//check descent condition
+			else if( tmpItem.getInputs() != null ) {
+				stackItem.push(tmpItem.getInputs()[tmpPos.intValue()]);
+				tmpPos.increment();
+				stackPos.push(new MutableInt(0));
+			}
+		}
+		return sb.toString();
+	}
+	
+	@Deprecated
+	@SuppressWarnings("unused")
 	private static String explainLineageItem(LineageItem li, int level) {
 		if( li.isVisited())
 			return "";
diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
index 19b949d..5475b09 100644
--- a/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
+++ b/src/test/java/org/apache/sysds/test/functions/lineage/LineageReuseAlg.java
@@ -48,23 +48,38 @@ public class LineageReuseAlg extends AutomatedTestBase {
 		for( int i=1; i<=TEST_VARIANTS; i++ )
 			addTestConfiguration(TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
 	}
+
+	@Test
+	public void testStepLMHybrid() {
+		testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_HYBRID);
+	}
+	
+	@Test
+	public void testGridSearchLMHybrid() {
+		testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_HYBRID);
+	}
+
+	@Test
+	public void testMultiLogRegHybrid() {
+		testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_HYBRID);
+	}
 	
 	@Test
-	public void testStepLM() {
-		testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_HYBRID.name().toLowerCase());
+	public void testStepLMFull() {
+		testLineageTrace(TEST_NAME+"1", ReuseCacheType.REUSE_FULL);
 	}
 	
 	@Test
-	public void testGridSearchLM() {
-		testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_HYBRID.name().toLowerCase());
+	public void testGridSearchLMFull() {
+		testLineageTrace(TEST_NAME+"2", ReuseCacheType.REUSE_FULL);
 	}
 
 	@Test
-	public void testMultiLogReg() {
-		testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_HYBRID.name().toLowerCase());
+	public void testMultiLogRegFull() {
+		testLineageTrace(TEST_NAME+"3", ReuseCacheType.REUSE_FULL);
 	}
 
-	public void testLineageTrace(String testname, String reuseType) {
+	public void testLineageTrace(String testname, ReuseCacheType reuseType) {
 		boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
 		boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
 		ExecMode platformOld = setExecMode(ExecType.CP);
@@ -85,7 +100,6 @@ public class LineageReuseAlg extends AutomatedTestBase {
 			proArgs.add("-args");
 			proArgs.add(output("X"));
 			programArgs = proArgs.toArray(new String[proArgs.size()]);
-			
 			Lineage.resetInternalState();
 			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
 			HashMap<MatrixValue.CellIndex, Double> X_orig = readDMLMatrixFromHDFS("X");
@@ -93,17 +107,16 @@ public class LineageReuseAlg extends AutomatedTestBase {
 			// With lineage-based reuse enabled
 			proArgs.clear();
 			proArgs.add("-stats");
-			//proArgs.add("-explain");
 			proArgs.add("-lineage");
-			proArgs.add(reuseType);
+			proArgs.add(reuseType.name().toLowerCase());
 			proArgs.add("-args");
 			proArgs.add(output("X"));
 			programArgs = proArgs.toArray(new String[proArgs.size()]);
-			
 			Lineage.resetInternalState();
 			Lineage.setLinReuseFull();
 			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
 			HashMap<MatrixValue.CellIndex, Double> X_reused = readDMLMatrixFromHDFS("X");
+			
 			Lineage.setLinReuseNone();
 			TestUtils.compareMatrices(X_orig, X_reused, 1e-6, "Origin", "Reused");
 		}
@@ -114,4 +127,4 @@ public class LineageReuseAlg extends AutomatedTestBase {
 			Recompiler.reinitRecompiler();
 		}
 	}
-}
\ No newline at end of file
+}
diff --git a/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml b/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml
index 576182b..96b03e4 100644
--- a/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml
+++ b/src/test/scripts/functions/lineage/LineageTraceParforSteplm.dml
@@ -21,6 +21,6 @@
 
 X = rand(rows=$2, cols=$3, seed=7);
 Y = rand(rows=nrow(X), cols=1, seed=2)
-X = steplm(X=X, y=Y)
+[B,S] = steplm(X=X, y=Y)
 
-write(X, $1);
+write(B, $1);