You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/09/13 16:52:25 UTC

[systemds] branch master updated: [SYSTEMDS-2641] Extended IPA rewrite handling (rebuild fgraph on demand)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 57315fb  [SYSTEMDS-2641] Extended IPA rewrite handling (rebuild fgraph on demand)
57315fb is described below

commit 57315fb542af690fdf718686be2a6fd9296a4842
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Sep 13 18:40:10 2020 +0200

    [SYSTEMDS-2641] Extended IPA rewrite handling (rebuild fgraph on demand)
    
    This patch improves the inter-procedural analysis, which repeatedly
    propagates scalars and sizes, and applies various IPA rewrite passes.
    One of these rewrite passes is a second function inlining mechanism,
    which inlines functions that are called once or small functions with
    fewer than t=10 operators. However, this condition was based on a
    function call graph that was never rebuilt.
    
    In slicefinder, after scalar propagation, one of the two calls to
    evalSlice gets removed (via removal of unnecessary branches), but inlining
    did not take place because two calls to this function were still
    mistakenly assumed. We now propagate the information about removed
    branches to IPA and rebuild the function call graph if necessary.
---
 src/main/java/org/apache/sysds/hops/ipa/IPAPass.java     |  6 ++++--
 .../ipa/IPAPassApplyStaticAndDynamicHopRewrites.java     | 16 +++++++++++-----
 .../apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java  |  3 ++-
 .../hops/ipa/IPAPassFlagFunctionsRecompileOnce.java      |  5 +++--
 .../apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java |  5 +++--
 .../sysds/hops/ipa/IPAPassForwardFunctionCalls.java      |  3 ++-
 .../apache/sysds/hops/ipa/IPAPassInlineFunctions.java    |  3 ++-
 .../sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java  |  3 ++-
 .../sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java   |  6 +++---
 .../hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java    |  3 ++-
 .../sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java     |  3 ++-
 .../apache/sysds/hops/ipa/InterProceduralAnalysis.java   |  9 +++++++--
 .../org/apache/sysds/hops/rewrite/ProgramRewriter.java   |  6 ++++--
 .../test/functions/builtin/BuiltinSliceFinderTest.java   |  4 ++++
 14 files changed, 51 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
index 7807a23..74a0b1d 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPass.java
@@ -29,7 +29,7 @@ import org.apache.sysds.parser.DMLProgram;
 public abstract class IPAPass 
 {
 	protected static final Log LOG = LogFactory.getLog(IPAPass.class.getName());
-    
+
 	/**
 	 * Indicates if an IPA pass is applicable for the current
 	 * configuration such as global flags or the chosen execution 
@@ -47,6 +47,8 @@ public abstract class IPAPass
 	 * @param prog dml program
 	 * @param fgraph function call graph
 	 * @param fcallSizes function call size infos
+	 * @return true if function call graph should be rebuilt
 	 */
-	public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes );
+	public abstract boolean rewriteProgram( DMLProgram prog,
+		FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes );
 }
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
index c00b73e..fdd8af0 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassApplyStaticAndDynamicHopRewrites.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.ipa;
 
 
 import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.hops.rewrite.RewriteInjectSparkLoopCheckpointing;
 import org.apache.sysds.parser.DMLProgram;
@@ -42,17 +43,22 @@ public class IPAPassApplyStaticAndDynamicHopRewrites extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
 		try {
-			//construct rewriter w/o checkpoint injection to avoid redundancy
+			// construct rewriter w/o checkpoint injection to avoid redundancy
 			ProgramRewriter rewriter = new ProgramRewriter(
 				InterProceduralAnalysis.APPLY_STATIC_REWRITES,
 				InterProceduralAnalysis.APPLY_DYNAMIC_REWRITES);
 			rewriter.removeStatementBlockRewrite(RewriteInjectSparkLoopCheckpointing.class);
 			
-			//rewrite program hop dags and statement blocks
-			rewriter.rewriteProgramHopDAGs(prog, true); //rewrite and split
-		} 
+			// rewrite program hop dags and statement blocks
+			ProgramRewriteStatus status = new ProgramRewriteStatus();
+			rewriter.rewriteProgramHopDAGs(prog, true, status); //rewrite and split
+			// in case of removed branches entire function calls might have been eliminated,
+			// accordingly, we should rebuild the function call graph to allow for inlining
+			// even large functions, and avoid restrictions of scalar/size propagation
+			return status.getRemovedBranches();
+		}
 		catch (LanguageException ex) {
 			throw new HopsException(ex);
 		}
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
index 043c4b2..b3aae5e 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassEliminateDeadCode.java
@@ -52,7 +52,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
 		// step 1: backwards pass over main program to track used and remove unused vars
 		findAndRemoveDeadCode(prog.getStatementBlocks(), new HashSet<>(), fgraph);
 		
@@ -66,6 +66,7 @@ public class IPAPassEliminateDeadCode extends IPAPass
 			// backward pass over function to track used and remove unused vars
 			findAndRemoveDeadCode(fstmt.getBody(), usedVars, fgraph);
 		}
+		return false;
 	}
 	
 	private static void findAndRemoveDeadCode(List<StatementBlock> sbs, Set<String> usedVars, FunctionCallGraph fgraph) {
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
index c0cf6c7..b6351ba 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
@@ -53,10 +53,10 @@ public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
 	{
 		if( !ConfigurationManager.isDynamicRecompilation() )
-			return;
+			return false;
 		
 		try {
 			// flag applicable functions for recompile-once, note that this IPA pass
@@ -82,6 +82,7 @@ public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
 		catch( LanguageException ex ) {
 			throw new HopsException(ex);
 		}
+		return false;
 	}
 	
 	/**
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
index a000096..6275f10 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
@@ -48,10 +48,10 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
 	}
 
 	@Override
-	public void rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) 
+	public boolean rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) 
 	{
 		if (!LineageCacheConfig.isMultiLevelReuse())
-			return;
+			return false;
 		
 		try {
 			// Find the individual functions and statementblocks with non-determinism.
@@ -84,6 +84,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
 		catch( LanguageException ex ) {
 			throw new HopsException(ex);
 		}
+		return false;
 	}
 
 	private boolean rIsNonDeterministicFnc (String fname, ArrayList<StatementBlock> sbs) 
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
index 1605524..8b57742 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
@@ -47,7 +47,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
 	{
 		for( String fkey : fgraph.getReachableFunctions() ) {
 			FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
@@ -87,6 +87,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
 						+ fkey +"' with '"+call2.getFunctionKey()+"'");
 			}
 		}
+		return false;
 	}
 	
 	private static boolean singleFunctionOp(ArrayList<Hop> hops) {
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
index 8c89689..3a465db 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
@@ -53,7 +53,7 @@ public class IPAPassInlineFunctions extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
 	{
 		//NOTE: we inline single-statement-block (i.e., last-level block) functions
 		//that do not contain other functions, and either are small or called once
@@ -133,6 +133,7 @@ public class IPAPassInlineFunctions extends IPAPass
 				}
 			}
 		}
+		return false;
 	}
 	
 	private static boolean containsFunctionOp(ArrayList<Hop> hops) {
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
index d9eac84..0f33a45 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
@@ -56,7 +56,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) 
 	{
 		//step 1: propagate final literals across main program
 		rReplaceLiterals(prog.getStatementBlocks(), prog, fgraph, fcallSizes);
@@ -93,6 +93,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
 				rReplaceLiterals(fstmt.getBody(), prog, fgraph, fcallSizes);
 			}
 		}
+		return false;
 	}
 	
 	private void rReplaceLiterals(List<StatementBlock> sbs, DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index 606c677..ffb7d68 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -56,13 +56,12 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
 		//approach: scan over top-level program (guaranteed to be unconditional),
 		//collect ones=matrix(1,...); remove b(*)ones if not outer operation
 		HashMap<String, Hop> mOnes = new HashMap<>();
 		
-		for( StatementBlock sb : prog.getStatementBlocks() ) 
-		{
+		for( StatementBlock sb : prog.getStatementBlocks() )  {
 			//pruning updated variables
 			for( String var : sb.variablesUpdated().getVariableNames() )
 				if( mOnes.containsKey( var ) )
@@ -79,6 +78,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
 				collectMatrixOfOnes(sb.getHops(), mOnes);
 			}
 		}
+		return false;
 	}
 	
 	private static void collectMatrixOfOnes(ArrayList<Hop> roots, HashMap<String,Hop> mOnes)
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index 351d099..c78aac6 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -55,7 +55,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
 		//remove unnecessary checkpoint before update 
 		removeCheckpointBeforeUpdate(prog);
 		
@@ -64,6 +64,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass
 		
 		//remove unnecessary checkpoint read-{write|uagg}
 		removeCheckpointReadWrite(prog);
+		return false;
 	}
 	
 	private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
index 6d6abc8..9304926 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnusedFunctions.java
@@ -44,7 +44,7 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass
 	}
 	
 	@Override
-	public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
+	public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes ) {
 		try {
 			Set<String> fnamespaces = prog.getNamespaces().keySet();
 			for( String fnspace : fnamespaces  ) {
@@ -64,5 +64,6 @@ public class IPAPassRemoveUnusedFunctions extends IPAPass
 		catch(LanguageException ex) {
 			throw new HopsException(ex);
 		}
+		return false;
 	}
 }
diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index abb8b81..579af77 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -97,7 +97,7 @@ public class InterProceduralAnalysis
 	private final StatementBlock _sb;
 	
 	//function call graph for functions reachable from main
-	private final FunctionCallGraph _fgraph;
+	private FunctionCallGraph _fgraph;
 	
 	//set IPA passes to apply in order 
 	private final ArrayList<IPAPass> _passes;
@@ -200,9 +200,10 @@ public class InterProceduralAnalysis
 			}
 			
 			//step 2: apply additional IPA passes
+			boolean rebuildFGraph = false;
 			for( IPAPass pass : _passes )
 				if( pass.isApplicable(_fgraph) )
-					pass.rewriteProgram(_prog, _fgraph, fcallSizes);
+					rebuildFGraph |= pass.rewriteProgram(_prog, _fgraph, fcallSizes);
 			
 			//early abort without functions or on reached fixpoint
 			if( _fgraph.getReachableFunctions().isEmpty() 
@@ -212,6 +213,10 @@ public class InterProceduralAnalysis
 						+ " repetitions due to reached fixpoint.");
 				break;
 			}
+			
+			//step 3: rebuild function call graph if necessary
+			if( rebuildFGraph && i < repetitions-1 )
+				_fgraph = new FunctionCallGraph(_prog);
 		}
 		
 		//cleanup pass: remove unused functions
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index 87df183..af81e86 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -191,8 +191,10 @@ public class ProgramRewriter
 	}
 	
 	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags) {
-		ProgramRewriteStatus state = new ProgramRewriteStatus();
-		
+		return rewriteProgramHopDAGs(dmlp, splitDags, new ProgramRewriteStatus());
+	}
+	
+	public ProgramRewriteStatus rewriteProgramHopDAGs(DMLProgram dmlp, boolean splitDags, ProgramRewriteStatus state) {
 		// for each namespace, handle function statement blocks
 		for (String namespaceKey : dmlp.getNamespaces().keySet())
 			for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
index ff9b639..a5dd9a7 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinSliceFinderTest.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.test.functions.builtin;
 
 
+import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -110,6 +111,9 @@ public class BuiltinSliceFinderTest extends AutomatedTestBase {
 			double[][] ret = TestUtils.convertHashMapToDoubleArray(readDMLMatrixFromHDFS("R"));
 			for(int i=0; i<K; i++)
 				TestUtils.compareMatrices(EXPECTED_TOPK[i], ret[i], 1e-2);
+		
+			//ensure proper inlining, despite initially multiple calls and large function
+			Assert.assertFalse(heavyHittersContainsSubString("evalSlice"));
 		}
 		finally {
 			rtplatform = platformOld;