You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/06/14 20:01:57 UTC
systemml git commit: [SYSTEMML-1689] Basic refactoring of
inter-procedural analysis
Repository: systemml
Updated Branches:
refs/heads/master 0b698a42a -> e9fb7a028
[SYSTEMML-1689] Basic refactoring of inter-procedural analysis
This patch refactors the existing inter-procedural analysis into a core
pass for size propagation and additional well-defined IPA passes.
Furthermore, this extends the function call graph by additional
primitives for getting all function call operators, which is the basis
for a more fine-grained analysis for size and literal propagation.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e9fb7a02
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e9fb7a02
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e9fb7a02
Branch: refs/heads/master
Commit: e9fb7a028b2acf6ba2ce7209d18845074d02a082
Parents: 0b698a4
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed Jun 14 01:52:10 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed Jun 14 13:00:18 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/ipa/FunctionCallGraph.java | 68 ++-
.../java/org/apache/sysml/hops/ipa/IPAPass.java | 53 ++
.../ipa/IPAPassFlagFunctionsRecompileOnce.java | 117 ++++
.../ipa/IPAPassRemoveConstantBinaryOps.java | 163 +++++
.../IPAPassRemoveUnnecessaryCheckpoints.java | 296 +++++++++
.../hops/ipa/IPAPassRemoveUnusedFunctions.java | 70 +++
.../sysml/hops/ipa/InterProceduralAnalysis.java | 605 +++----------------
.../org/apache/sysml/parser/DMLTranslator.java | 6 +-
.../parfor/opt/OptimizationWrapper.java | 4 +-
9 files changed, 840 insertions(+), 542 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
index f872c97..9e55eaa 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallGraph.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.ipa;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Stack;
@@ -49,6 +50,10 @@ public class FunctionCallGraph
//(mapping from function keys to called function keys)
private final HashMap<String, HashSet<String>> _fGraph;
+ //program-wide function call operators per target function
+ //(mapping from function keys to set of its function calls)
+ private final HashMap<String, ArrayList<FunctionOp>> _fCalls;
+
//subset of direct or indirect recursive functions
private final HashSet<String> _fRecursive;
@@ -60,10 +65,25 @@ public class FunctionCallGraph
*/
public FunctionCallGraph(DMLProgram prog) {
_fGraph = new HashMap<String, HashSet<String>>();
+ _fCalls = new HashMap<String, ArrayList<FunctionOp>>();
_fRecursive = new HashSet<String>();
constructFunctionCallGraph(prog);
}
+
+ /**
+  * Creates the function call graph for all functions that are
+  * reachable from the given statement block.
+  *
+  * @param sb statement block (potentially hierarchical) used as
+  *   entry point of the graph construction
+  */
+ public FunctionCallGraph(StatementBlock sb) {
+     //start with empty call edges, call operators, and recursive functions
+     this._fGraph = new HashMap<>();
+     this._fCalls = new HashMap<>();
+     this._fRecursive = new HashSet<>();
+
+     //populate the data structures from the given block
+     this.constructFunctionCallGraph(sb);
+ }
/**
* Returns all functions called from the given function.
@@ -89,6 +109,20 @@ public class FunctionCallGraph
}
/**
+ * Returns all function operators calling the given function.
+ *
+ * @param fkey function key of called function,
+ * null indicates the main program and returns an empty list
+ * @return list of function call hops
+ */
+ public Collection<FunctionOp> getFunctionCalls(String fkey) {
+     //main program cannot have function calls
+     if( fkey == null )
+         return Collections.emptyList();
+
+     //function keys without any recorded function call operator have
+     //no map entry; return an empty list instead of null in this case,
+     //consistent with the main program, so that callers can iterate
+     //over the result without an additional null check
+     ArrayList<FunctionOp> calls = _fCalls.get(fkey);
+     if( calls == null )
+         return Collections.emptyList();
+     return calls;
+ }
+
+ /**
* Indicates if the given function is either directly or indirectly recursive.
* An example of an indirect recursive function is foo2 in the following call
* chain: foo1 -> foo2 -> foo1.
@@ -117,6 +151,16 @@ public class FunctionCallGraph
/**
* Returns all functions that are reachable either directly or indirectly
+ * from the main program, except the main program itself.
+ *
+ * @return list of function keys (namespace and name)
+ */
+ public Collection<String> getReachableFunctions() {
+ //delegate to the blacklist variant with an empty blacklist
+ return getReachableFunctions(Collections.emptyList());
+ }
+
+ /**
+ * Returns all functions that are reachable either directly or indirectly
* from the main program, except the main program itself and the given
* blacklist of function names.
*
@@ -159,7 +203,7 @@ public class FunctionCallGraph
private void constructFunctionCallGraph(DMLProgram prog) {
if( !prog.hasFunctionStatementBlocks() )
return; //early abort if prog without functions
-
+
try {
Stack<String> fstack = new Stack<String>();
HashSet<String> lfset = new HashSet<String>();
@@ -172,6 +216,21 @@ public class FunctionCallGraph
}
}
+ private void constructFunctionCallGraph(StatementBlock sb) {
+ if( !sb.getDMLProg().hasFunctionStatementBlocks() )
+ return; //early abort if prog without functions
+
+ try {
+ Stack<String> fstack = new Stack<String>();
+ HashSet<String> lfset = new HashSet<String>();
+ _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
+ rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sb, fstack, lfset);
+ }
+ catch(HopsException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
private void rConstructFunctionCallGraph(String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset)
throws HopsException
{
@@ -208,10 +267,15 @@ public class FunctionCallGraph
if( h instanceof FunctionOp ){
FunctionOp fop = (FunctionOp) h;
String lfkey = DMLProgram.constructFunctionKey(fop.getFunctionNamespace(), fop.getFunctionName());
+ //keep all function operators
+ if( !_fCalls.containsKey(lfkey) )
+ _fCalls.put(lfkey, new ArrayList<FunctionOp>());
+ _fCalls.get(lfkey).add(fop);
+
//prevent redundant call edges
if( lfset.contains(lfkey) || fop.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE) )
continue;
-
+
if( !_fGraph.containsKey(lfkey) )
_fGraph.put(lfkey, new HashSet<String>());
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
new file mode 100644
index 0000000..cfd9df7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPass.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.parser.DMLProgram;
+
+/**
+ * Base class for all IPA passes. A pass reports via isApplicable
+ * whether it should be used and performs its work in rewriteProgram.
+ */
+public abstract class IPAPass
+{
+ //logger shared by all passes (created under the name of this base class)
+ protected static final Log LOG = LogFactory.getLog(IPAPass.class.getName());
+
+ /**
+ * Indicates if an IPA pass is applicable for the current
+ * configuration such as global flags or the chosen execution
+ * mode (e.g., hybrid_spark).
+ *
+ * @return true if applicable.
+ */
+ public abstract boolean isApplicable();
+
+ /**
+ * Rewrites the given program or its functions in place,
+ * with access to the read-only function call graph.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @throws HopsException if the pass fails to rewrite the program
+ */
+ public abstract void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph )
+ throws HopsException;
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
new file mode 100644
index 0000000..ee072e4
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassFlagFunctionsRecompileOnce.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.FunctionStatement;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.IfStatement;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.LanguageException;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatementBlock;
+
+/**
+ * This rewrite marks functions with loops as recompile once
+ * in order to reduce recompilation overhead. Such functions
+ * are recompiled on function entry with the size information
+ * of the function inputs which is often sufficient to decide
+ * upon execution types; in case there are still unknowns, the
+ * traditional recompilation per atomic block still applies.
+ *
+ * TODO call after lops construction
+ */
+public class IPAPassFlagFunctionsRecompileOnce extends IPAPass
+{
+ /**
+ * Applicability is solely determined by the flag
+ * InterProceduralAnalysis.FLAG_FUNCTION_RECOMPILE_ONCE.
+ */
+ @Override
+ public boolean isApplicable() {
+ return InterProceduralAnalysis.FLAG_FUNCTION_RECOMPILE_ONCE;
+ }
+
+ /**
+  * Flags all non-recursive functions of the program as recompile-once
+  * if the scan of their function body requests it.
+  *
+  * @param prog dml program
+  * @param fgraph function call graph, used to identify recursive functions
+  * @throws HopsException wraps any language exception raised on function lookup
+  */
+ @Override
+ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph )
+     throws HopsException
+ {
+     try {
+         for( String nsKey : prog.getNamespaces().keySet() ) {
+             for( String fname : prog.getFunctionStatementBlocks(nsKey).keySet() ) {
+                 FunctionStatementBlock fsb = prog.getFunctionStatementBlock(nsKey, fname);
+
+                 //recursive functions are never flagged
+                 if( fgraph.isRecursiveFunction(nsKey, fname) )
+                     continue;
+                 //functions not selected by the scan remain untouched
+                 if( !rFlagFunctionForRecompileOnce(fsb, false) )
+                     continue;
+
+                 fsb.setRecompileOnce(true);
+                 if( LOG.isDebugEnabled() ) {
+                     String fkey = DMLProgram.constructFunctionKey(nsKey, fname);
+                     LOG.debug("IPA: FUNC flagged for recompile-once: " + fkey);
+                 }
+             }
+         }
+     }
+     catch( LanguageException e ) {
+         throw new HopsException(e);
+     }
+ }
+
+ /**
+ * Returns true if this statementblock requires recompilation inside a
+ * loop statement block.
+ *
+ * @param sb statement block
+ * @param inLoop true if in loop
+ * @return true if statement block requires recompilation inside a loop statement block
+ */
+ public boolean rFlagFunctionForRecompileOnce( StatementBlock sb, boolean inLoop )
+ {
+ boolean ret = false;
+
+ if (sb instanceof FunctionStatementBlock) {
+ FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+ FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
+ for( StatementBlock c : fstmt.getBody() )
+ ret |= rFlagFunctionForRecompileOnce( c, inLoop );
+ }
+ else if (sb instanceof WhileStatementBlock) {
+ //recompilation information not available at this point
+ //hence, mark any loop statement block
+ ret = true;
+ }
+ else if (sb instanceof IfStatementBlock) {
+ IfStatementBlock isb = (IfStatementBlock) sb;
+ IfStatement istmt = (IfStatement)isb.getStatement(0);
+ ret |= (inLoop && isb.requiresPredicateRecompilation() );
+ for( StatementBlock c : istmt.getIfBody() )
+ ret |= rFlagFunctionForRecompileOnce( c, inLoop );
+ for( StatementBlock c : istmt.getElseBody() )
+ ret |= rFlagFunctionForRecompileOnce( c, inLoop );
+ }
+ else if (sb instanceof ForStatementBlock) {
+ //recompilation information not available at this point
+ //hence, mark any loop statement block
+ ret = true;
+ }
+ else {
+ ret |= ( inLoop && sb.requiresRecompilation() );
+ }
+
+ return ret;
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
new file mode 100644
index 0000000..c71ed45
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataGenOp;
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.Hop.DataGenMethod;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.ForStatement;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.IfStatement;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatement;
+import org.apache.sysml.parser.WhileStatementBlock;
+import org.apache.sysml.parser.Expression.DataType;
+
+/**
+ * This rewrite identifies binary operations with constant matrices
+ * such as X * ones, where ones might be created as a vector of ones
+ * before a loop. Such operations frequently occur after branch removal
+ * for fixed configurations or loss functions.
+ *
+ */
+public class IPAPassRemoveConstantBinaryOps extends IPAPass
+{
+ /**
+ * Applicability is solely determined by the flag
+ * InterProceduralAnalysis.REMOVE_CONSTANT_BINARY_OPS.
+ */
+ @Override
+ public boolean isApplicable() {
+ return InterProceduralAnalysis.REMOVE_CONSTANT_BINARY_OPS;
+ }
+
+ @Override
+ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph )
+ throws HopsException
+ {
+ //approach: scan over top-level program (guaranteed to be unconditional),
+ //collect ones=matrix(1,...); remove b(*)ones if not outer operation
+ HashMap<String, Hop> mOnes = new HashMap<String, Hop>();
+
+ for( StatementBlock sb : prog.getStatementBlocks() )
+ {
+ //pruning updated variables
+ for( String var : sb.variablesUpdated().getVariableNames() )
+ if( mOnes.containsKey( var ) )
+ mOnes.remove( var );
+
+ //replace constant binary ops
+ if( !mOnes.isEmpty() )
+ rRemoveConstantBinaryOp(sb, mOnes);
+
+ //collect matrices of ones from last-level statement blocks
+ if( !(sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
+ || sb instanceof ForStatementBlock) )
+ {
+ collectMatrixOfOnes(sb.get_hops(), mOnes);
+ }
+ }
+ }
+
+ /**
+  * Registers all transient writes whose input is a rand data generator
+  * with constant value 1 under the name of the written variable.
+  *
+  * @param roots dag roots of a statement block, might be null
+  * @param mOnes map from variable names to data generator hops, modified in place
+  */
+ private static void collectMatrixOfOnes(ArrayList<Hop> roots, HashMap<String,Hop> mOnes)
+ {
+     if( roots == null )
+         return;
+
+     for( Hop root : roots ) {
+         //only transient writes are of interest
+         if( !(root instanceof DataOp)
+             || ((DataOp)root).getDataOpType()!=DataOpTypes.TRANSIENTWRITE )
+             continue;
+
+         //the written data has to be generated by rand with constant 1
+         Hop in = root.getInput().get(0);
+         if( !(in instanceof DataGenOp) )
+             continue;
+         DataGenOp dgen = (DataGenOp) in;
+         if( dgen.getOp()==DataGenMethod.RAND && dgen.hasConstantValue(1.0) )
+             mOnes.put(root.getName(), in);
+     }
+ }
+
+ /**
+  * Recursively descends into the bodies of if, while, and for statement
+  * blocks and applies the hop-level replacement to the dag roots of all
+  * other statement blocks.
+  *
+  * @param sb statement block (potentially hierarchical)
+  * @param mOnes map from variable names to matrices of ones
+  * @throws HopsException if the processing of a nested block fails
+  */
+ private static void rRemoveConstantBinaryOp(StatementBlock sb, HashMap<String,Hop> mOnes)
+     throws HopsException
+ {
+     if( sb instanceof IfStatementBlock ) {
+         IfStatementBlock ifBlock = (IfStatementBlock) sb;
+         IfStatement ifStmt = (IfStatement) ifBlock.getStatement(0);
+         for( StatementBlock child : ifStmt.getIfBody() )
+             rRemoveConstantBinaryOp(child, mOnes);
+         if( ifStmt.getElseBody() != null ) {
+             for( StatementBlock child : ifStmt.getElseBody() )
+                 rRemoveConstantBinaryOp(child, mOnes);
+         }
+         return;
+     }
+
+     if( sb instanceof WhileStatementBlock ) {
+         WhileStatementBlock whileBlock = (WhileStatementBlock) sb;
+         WhileStatement whileStmt = (WhileStatement) whileBlock.getStatement(0);
+         for( StatementBlock child : whileStmt.getBody() )
+             rRemoveConstantBinaryOp(child, mOnes);
+         return;
+     }
+
+     if( sb instanceof ForStatementBlock ) {
+         ForStatementBlock forBlock = (ForStatementBlock) sb;
+         ForStatement forStmt = (ForStatement) forBlock.getStatement(0);
+         for( StatementBlock child : forStmt.getBody() )
+             rRemoveConstantBinaryOp(child, mOnes);
+         return;
+     }
+
+     //all other statement blocks: process the hop dags, if any
+     if( sb.get_hops() == null )
+         return;
+     Hop.resetVisitStatus(sb.get_hops());
+     for( Hop root : sb.get_hops() )
+         rRemoveConstantBinaryOp(root, mOnes);
+ }
+
+ /**
+  * Recursively processes the dag below the given hop and replaces the
+  * second input of non-outer matrix multiplications (OpOp2.MULT) with
+  * the literal 1 if that input is a data op whose name is registered
+  * as matrix of ones.
+  *
+  * @param hop current hop of the dag traversal
+  * @param mOnes map from variable names to matrices of ones
+  */
+ private static void rRemoveConstantBinaryOp(Hop hop, HashMap<String,Hop> mOnes)
+ {
+     if( hop.isVisited() )
+         return;
+
+     if( hop instanceof BinaryOp ) {
+         BinaryOp bop = (BinaryOp) hop;
+         if( bop.getOp()==OpOp2.MULT && !bop.isOuterVectorOperator()
+             && hop.getInput().get(0).getDataType()==DataType.MATRIX )
+         {
+             Hop right = hop.getInput().get(1);
+             if( right instanceof DataOp && mOnes.containsKey(right.getName()) ) {
+                 //replace matrix of ones with literal 1 (later on removed by
+                 //algebraic simplification rewrites; otherwise more complex
+                 //recursive processing of childs and rewiring required)
+                 HopRewriteUtils.removeChildReferenceByPos(hop, right, 1);
+                 HopRewriteUtils.addChildReference(hop, new LiteralOp(1), 1);
+             }
+         }
+     }
+
+     //recursively process child nodes
+     for( Hop child : hop.getInput() )
+         rRemoveConstantBinaryOp(child, mOnes);
+
+     hop.setVisited();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
new file mode 100644
index 0000000..20c47da
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.DataOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.ForStatementBlock;
+import org.apache.sysml.parser.IfStatementBlock;
+import org.apache.sysml.parser.StatementBlock;
+import org.apache.sysml.parser.WhileStatementBlock;
+
+/**
+ * This rewrite identifies and removes unnecessary checkpoints, i.e.,
+ * persisting of Spark RDDs into a given storage level. For example,
+ * in chains such as pread-checkpoint-append-checkpoint, the first
+ * checkpoint is not used and creates unnecessary memory pressure.
+ *
+ */
+public class IPAPassRemoveUnnecessaryCheckpoints extends IPAPass
+{
+ /**
+ * Applicable only if the flag
+ * InterProceduralAnalysis.REMOVE_UNNECESSARY_CHECKPOINTS is set
+ * and OptimizerUtils reports a Spark execution mode.
+ */
+ @Override
+ public boolean isApplicable() {
+ return InterProceduralAnalysis.REMOVE_UNNECESSARY_CHECKPOINTS
+ && OptimizerUtils.isSparkExecutionMode();
+ }
+
+ /**
+ * Applies the three checkpoint rewrites of this pass in a fixed
+ * order to the given program; the function call graph is not used.
+ *
+ * @param prog dml program
+ * @param fgraph function call graph
+ * @throws HopsException if any of the three rewrites fails
+ */
+ @Override
+ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph )
+ throws HopsException
+ {
+ //remove unnecessary checkpoint before update
+ removeCheckpointBeforeUpdate(prog);
+
+ //move necessary checkpoint after update
+ moveCheckpointAfterUpdate(prog);
+
+ //remove unnecessary checkpoint read-{write|uagg}
+ removeCheckpointReadWrite(prog);
+ }
+
+ /**
+ * Scans the top-level statement blocks in program order and keeps a map
+ * from variable names to candidate checkpoint hops. Whenever a statement
+ * block yields a new checkpoint for a variable that still has a candidate,
+ * the checkpoint flag of the previous candidate is cleared.
+ *
+ * Candidates are dropped if the variable is read but not updated by a
+ * block, if it is updated by an if, while, or for block, or if it is
+ * updated in another block by a root without a simple read chain.
+ *
+ * @param dmlp dml program
+ * @throws HopsException if the analysis of a hop dag fails
+ */
+ private static void removeCheckpointBeforeUpdate(DMLProgram dmlp)
+ throws HopsException
+ {
+ //approach: scan over top-level program (guaranteed to be unconditional),
+ //collect checkpoints; determine if used before update; remove first checkpoint
+ //on second checkpoint if update in between and not used before update
+
+ HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
+
+ for( StatementBlock sb : dmlp.getStatementBlocks() )
+ {
+ //prune candidates (used before updated)
+ //(iterate over a copy of the key set because candidates are removed)
+ Set<String> cands = new HashSet<String>(chkpointCand.keySet());
+ for( String cand : cands )
+ if( sb.variablesRead().containsVariable(cand)
+ && !sb.variablesUpdated().containsVariable(cand) )
+ {
+ //note: variableRead might include false positives due to meta
+ //data operations like nrow(X) or operations removed by rewrites
+ //double check hops on basic blocks; otherwise worst-case
+ //(skipRemove stays true only if hops exist and no root contains a read)
+ boolean skipRemove = false;
+ if( sb.get_hops() !=null ) {
+ Hop.resetVisitStatus(sb.get_hops());
+ skipRemove = true;
+ for( Hop root : sb.get_hops() )
+ skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
+ }
+ if( !skipRemove )
+ chkpointCand.remove(cand);
+ }
+
+ //prune candidates (updated in conditional control flow)
+ Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
+ if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
+ || sb instanceof ForStatementBlock )
+ {
+ for( String cand : cands2 )
+ if( sb.variablesUpdated().containsVariable(cand) ) {
+ chkpointCand.remove(cand);
+ }
+ }
+ //prune candidates (updated w/ multiple reads)
+ else
+ {
+ for( String cand : cands2 )
+ if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null)
+ {
+ Hop.resetVisitStatus(sb.get_hops());
+ for( Hop root : sb.get_hops() )
+ if( root.getName().equals(cand) &&
+ !HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
+ chkpointCand.remove(cand);
+ }
+ }
+ }
+
+ //collect checkpoints and remove unnecessary checkpoints
+ //(a new checkpoint replaces the remaining candidate of the same name)
+ ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
+ for( Hop chkpoint : tmp ) {
+ if( chkpointCand.containsKey(chkpoint.getName()) ) {
+ chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);
+ }
+ chkpointCand.put(chkpoint.getName(), chkpoint);
+ }
+
+ }
+ }
+
+ /**
+ * Scans the top-level statement blocks in program order and keeps a map
+ * from variable names to candidate checkpoint hops. If a candidate variable
+ * is updated in a block other than an if, while, or for block by a root
+ * with a simple read chain, the checkpoint flag is cleared on the candidate
+ * and set on the first input of that root, which becomes the new candidate.
+ *
+ * Candidates are dropped if the variable is read but not updated by a
+ * block, if it is updated by an if, while, or for block, or if the
+ * updating root has no simple read chain.
+ *
+ * @param dmlp dml program
+ * @throws HopsException if the analysis of a hop dag fails
+ */
+ private static void moveCheckpointAfterUpdate(DMLProgram dmlp)
+ throws HopsException
+ {
+ //approach: scan over top-level program (guaranteed to be unconditional),
+ //collect checkpoints; determine if used before update; move first checkpoint
+ //after update if not used before update (best effort move which often avoids
+ //the second checkpoint on loops even though used in between)
+
+ HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
+
+ for( StatementBlock sb : dmlp.getStatementBlocks() )
+ {
+ //prune candidates (used before updated)
+ //(iterate over a copy of the key set because candidates are removed)
+ Set<String> cands = new HashSet<String>(chkpointCand.keySet());
+ for( String cand : cands )
+ if( sb.variablesRead().containsVariable(cand)
+ && !sb.variablesUpdated().containsVariable(cand) )
+ {
+ //note: variableRead might include false positives due to meta
+ //data operations like nrow(X) or operations removed by rewrites
+ //double check hops on basic blocks; otherwise worst-case
+ //(skipRemove stays true only if hops exist and no root contains a read)
+ boolean skipRemove = false;
+ if( sb.get_hops() !=null ) {
+ Hop.resetVisitStatus(sb.get_hops());
+ skipRemove = true;
+ for( Hop root : sb.get_hops() )
+ skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
+ }
+ if( !skipRemove )
+ chkpointCand.remove(cand);
+ }
+
+ //prune candidates (updated in conditional control flow)
+ Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
+ if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
+ || sb instanceof ForStatementBlock )
+ {
+ for( String cand : cands2 )
+ if( sb.variablesUpdated().containsVariable(cand) ) {
+ chkpointCand.remove(cand);
+ }
+ }
+ //move checkpoint after update with simple read chain
+ //(note: right now this only applies if the checkpoint comes from a previous
+ //statement block, within-dag checkpoints should be handled during injection)
+ else
+ {
+ for( String cand : cands2 )
+ if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null) {
+ Hop.resetVisitStatus(sb.get_hops());
+ //NOTE(review): if several roots carried the same name, a removal
+ //by an earlier root would leave no candidate for the lookup
+ //below - presumably root names are unique per dag; confirm
+ for( Hop root : sb.get_hops() )
+ if( root.getName().equals(cand) ) {
+ if( HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
+ chkpointCand.get(cand).setRequiresCheckpoint(false);
+ root.getInput().get(0).setRequiresCheckpoint(true);
+ chkpointCand.put(cand, root.getInput().get(0));
+ }
+ else
+ chkpointCand.remove(cand);
+ }
+ }
+ }
+
+ //collect checkpoints
+ ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
+ for( Hop chkpoint : tmp ) {
+ chkpointCand.put(chkpoint.getName(), chkpoint);
+ }
+ }
+ }
+
+ /**
+  * Removes unnecessary checkpoints on persistent reads for programs that
+  * consist of exactly one top-level statement block which is not an if,
+  * while, or for statement block. All other programs, including programs
+  * without any top-level statement block, are left untouched.
+  *
+  * @param dmlp dml program
+  * @throws HopsException declared for consistency with the other rewrites
+  */
+ private static void removeCheckpointReadWrite(DMLProgram dmlp)
+     throws HopsException
+ {
+     List<StatementBlock> sbs = dmlp.getStatementBlocks();
+
+     //check the size before accessing the first block; the previous
+     //non-short-circuit '&' evaluated sbs.get(0) unconditionally and
+     //hence failed with an index-out-of-bounds error on empty programs
+     if( sbs.size() != 1 )
+         return;
+
+     StatementBlock sb = sbs.get(0);
+     if( sb instanceof IfStatementBlock
+         || sb instanceof WhileStatementBlock
+         || sb instanceof ForStatementBlock )
+         return;
+
+     //recursively process all dag roots
+     if( sb.get_hops() != null ) {
+         Hop.resetVisitStatus(sb.get_hops());
+         for( Hop root : sb.get_hops() )
+             rRemoveCheckpointReadWrite(root);
+     }
+ }
+
+ /**
+  * Collects the checkpoint hops found in the dags of the given roots
+  * after resetting the visit status of these dags.
+  *
+  * @param roots dag roots of a statement block, might be null
+  * @return list of collected hops, empty if there are no roots
+  */
+ private static ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots)
+ {
+     ArrayList<Hop> checkpoints = new ArrayList<Hop>();
+     if( roots == null )
+         return checkpoints;
+
+     Hop.resetVisitStatus(roots);
+     for( Hop root : roots )
+         rCollectCheckpoints(root, checkpoints);
+     return checkpoints;
+ }
+
+ /**
+  * Recursively traverses the dag below the given hop and adds every hop
+  * that requires a checkpoint and whose only parent is a transient write.
+  *
+  * @param hop current hop of the dag traversal
+  * @param checkpoints output list of collected hops, modified in place
+  */
+ private static void rCollectCheckpoints(Hop hop, ArrayList<Hop> checkpoints)
+ {
+     if( hop.isVisited() )
+         return;
+
+     //handle leaf node for variable (checkpoint directly bound
+     //to logical variable name and not used)
+     if( hop.requiresCheckpoint() && hop.getParent().size()==1 ) {
+         Hop parent = hop.getParent().get(0);
+         if( parent instanceof DataOp
+             && ((DataOp)parent).getDataOpType()==DataOpTypes.TRANSIENTWRITE )
+             checkpoints.add(hop);
+     }
+
+     //recursively process child nodes
+     for( Hop child : hop.getInput() )
+         rCollectCheckpoints(child, checkpoints);
+
+     hop.setVisited();
+ }
+
+ public static void rRemoveCheckpointReadWrite(Hop hop)
+ {
+ if( hop.isVisited() )
+ return;
+
+ //remove checkpoint on pread if only consumed by pwrite or uagg
+ if( (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTWRITE)
+ || hop instanceof AggUnaryOp )
+ {
+ //(pwrite|uagg) - pread
+ Hop c0 = hop.getInput().get(0);
+ if( c0.requiresCheckpoint() && c0.getParent().size() == 1
+ && c0 instanceof DataOp && ((DataOp)c0).getDataOpType()==DataOpTypes.PERSISTENTREAD )
+ {
+ c0.setRequiresCheckpoint(false);
+ }
+
+ //(pwrite|uagg) - frame/matri cast - pread
+ if( c0 instanceof UnaryOp && c0.getParent().size() == 1
+ && (((UnaryOp)c0).getOp()==OpOp1.CAST_AS_FRAME || ((UnaryOp)c0).getOp()==OpOp1.CAST_AS_MATRIX )
+ && c0.getInput().get(0).requiresCheckpoint() && c0.getInput().get(0).getParent().size() == 1
+ && c0.getInput().get(0) instanceof DataOp
+ && ((DataOp)c0.getInput().get(0)).getDataOpType()==DataOpTypes.PERSISTENTREAD )
+ {
+ c0.getInput().get(0).setRequiresCheckpoint(false);
+ }
+ }
+
+ //recursively process children
+ for( Hop c : hop.getInput() )
+ rRemoveCheckpointReadWrite(c);
+
+ hop.setVisited();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
new file mode 100644
index 0000000..f713b7b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/ipa/IPAPassRemoveUnusedFunctions.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.ipa;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.Map.Entry;
+
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.parser.DMLProgram;
+import org.apache.sysml.parser.FunctionStatementBlock;
+import org.apache.sysml.parser.LanguageException;
+
+/**
+ * This rewrite identifies and removes unused functions in order
+ * to reduce compilation overhead and other overheads such as
+ * parfor worker creation, where we construct function copies.
+ *
+ */
+public class IPAPassRemoveUnusedFunctions extends IPAPass
+{
+ @Override
+ public boolean isApplicable() {
+ return InterProceduralAnalysis.REMOVE_UNUSED_FUNCTIONS;
+ }
+
+ @Override
+ public void rewriteProgram( DMLProgram prog, FunctionCallGraph fgraph )
+ throws HopsException
+ {
+ try {
+ Set<String> fnamespaces = prog.getNamespaces().keySet();
+ for( String fnspace : fnamespaces ) {
+ HashMap<String, FunctionStatementBlock> fsbs = prog.getFunctionStatementBlocks(fnspace);
+ Iterator<Entry<String, FunctionStatementBlock>> iter = fsbs.entrySet().iterator();
+ while( iter.hasNext() ) {
+ Entry<String, FunctionStatementBlock> e = iter.next();
+ if( !fgraph.isReachableFunction(fnspace, e.getKey()) ) {
+ iter.remove();
+ if( LOG.isDebugEnabled() )
+ LOG.debug("IPA: Removed unused function: " +
+ DMLProgram.constructFunctionKey(fnspace, e.getKey()));
+ }
+ }
+ }
+ }
+ catch(LanguageException ex) {
+ throw new HopsException(ex);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
index 3562c9f..0602208 100644
--- a/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysml/hops/ipa/InterProceduralAnalysis.java
@@ -23,8 +23,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
@@ -34,21 +32,13 @@ import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.hops.AggUnaryOp;
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.FunctionOp;
import org.apache.sysml.hops.FunctionOp.FunctionType;
import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.Hop.DataGenMethod;
-import org.apache.sysml.hops.Hop.DataOpTypes;
-import org.apache.sysml.hops.Hop.OpOp1;
-import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.parser.DMLProgram;
@@ -89,28 +79,10 @@ import org.apache.sysml.udf.lib.OrderWrapper;
* changing sparsity etc (that requires the rewritten hops dag as input). This
* also includes control-flow aware propagation of size and sparsity. Furthermore,
* it also serves as a second constant propagation pass.
- *
- * In general, the basic concepts of IPA are as follows and all places that deal with
- * statistic propagation should adhere to that:
- * * Rule 1: Exact size propagation: Since the dimension information are sometimes used
- * for specific lops construction (e.g., in append) and rewrites, we cannot propagate worst-case
- * estimates but only exact information; otherwise size must be unknown.
- * * Rule 2: Dimension information and sparsity are handled separately, i.e., if an updated
- * variable has changing sparsity but constant dimensions, its dimensions are known but
- * sparsity unknown.
- *
- * More specifically, those two rules are currently realized as follows:
- * * Statistics propagation is applied for DML-bodied functions that are invoked exactly once.
- * This ensures that we can savely propagate exact information into this function.
- * If ALLOW_MULTIPLE_FUNCTION_CALLS is enabled we treat multiple calls with the same sizes
- * as one call and hence, propagate those statistics into the function as well.
- * * Output size inference happens for DML-bodied functions that are invoked exactly once
- * and for external functions that are known in advance (see UDFs in org.apache.sysml.udf).
- * * Size propagation across DAGs requires control flow awareness:
- * - Generic statement blocks: updated variables → old stats in; new stats out
- * - While/for statement blocks: updated variables → old stats in/out if loop insensitive; otherwise unknown
- * - If statement blocks: updated variables → old stats in; new stats out if branch-insensitive
- *
+ *
+ * Additionally, IPA also covers the removal of unused functions, the decision on
+ * recompile once functions, the removal of unnecessary checkpoints, and the
+ * global removal of constant binary operations such as X * ones.
*
*/
@SuppressWarnings("deprecation")
@@ -120,14 +92,14 @@ public class InterProceduralAnalysis
private static final Log LOG = LogFactory.getLog(InterProceduralAnalysis.class.getName());
//internal configuration parameters
- private static final boolean INTRA_PROCEDURAL_ANALYSIS = true; //propagate statistics across statement blocks (main/functions)
- private static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true; //propagate statistics for known external functions
- private static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS = true; //propagate consistent statistics from multiple calls
- private static final boolean REMOVE_UNUSED_FUNCTIONS = true; //remove unused functions (inlined or never called)
- private static final boolean FLAG_FUNCTION_RECOMPILE_ONCE = true; //flag functions which require recompilation inside a loop for full function recompile
- private static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates)
- private static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...))
- private static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once
+ protected static final boolean INTRA_PROCEDURAL_ANALYSIS = true; //propagate statistics across statement blocks (main/functions)
+ protected static final boolean PROPAGATE_KNOWN_UDF_STATISTICS = true; //propagate statistics for known external functions
+ protected static final boolean ALLOW_MULTIPLE_FUNCTION_CALLS = true; //propagate consistent statistics from multiple calls
+ protected static final boolean REMOVE_UNUSED_FUNCTIONS = true; //remove unused functions (inlined or never called)
+ protected static final boolean FLAG_FUNCTION_RECOMPILE_ONCE = true; //flag functions which require recompilation inside a loop for full function recompile
+ protected static final boolean REMOVE_UNNECESSARY_CHECKPOINTS = true; //remove unnecessary checkpoints (unconditionally overwritten intermediates)
+ protected static final boolean REMOVE_CONSTANT_BINARY_OPS = true; //remove constant binary operations (e.g., X*ones, where ones=matrix(1,...))
+ protected static final boolean PROPAGATE_SCALAR_VARS_INTO_FUN = true; //propagate scalar variables into functions that are called once
public static boolean UNARY_DIMS_PRESERVING_FUNS = true; //determine and exploit unary dimension preserving functions
static {
@@ -138,8 +110,46 @@ public class InterProceduralAnalysis
}
}
- public InterProceduralAnalysis() {
- //do nothing
+ private final DMLProgram _prog;
+ private final StatementBlock _sb;
+
+ //function call graph for functions reachable from main
+ private final FunctionCallGraph _fgraph;
+
+ //set IPA passes to apply in order
+ private final ArrayList<IPAPass> _passes;
+
+ /**
+ * Creates a handle for performing inter-procedural analysis
+ * for a given DML program and its associated HOP DAGs. This
+ * call initializes various internal information such as the
+ * function call graph and auxiliary function call information
+ * which can be reused across multiple IPA calls (e.g., for
+ * second chance analysis).
+ *
+ */
+ public InterProceduralAnalysis(DMLProgram dmlp) {
+ //analyzes the function call graph
+ _prog = dmlp;
+ _sb = null;
+ _fgraph = new FunctionCallGraph(dmlp);
+
+ //create ordered list of IPA passes
+ _passes = new ArrayList<IPAPass>();
+ _passes.add(new IPAPassRemoveUnusedFunctions());
+ _passes.add(new IPAPassFlagFunctionsRecompileOnce());
+ _passes.add(new IPAPassRemoveUnnecessaryCheckpoints());
+ _passes.add(new IPAPassRemoveConstantBinaryOps());
+ }
+
+ public InterProceduralAnalysis(StatementBlock sb) {
+ //analyzes the function call graph
+ _prog = sb.getDMLProg();
+ _sb = sb;
+ _fgraph = new FunctionCallGraph(sb);
+
+ //create ordered list of IPA passes
+ _passes = new ArrayList<IPAPass>();
}
/**
@@ -150,29 +160,30 @@ public class InterProceduralAnalysis
* @throws ParseException if ParseException occurs
* @throws LanguageException if LanguageException occurs
*/
- public void analyzeProgram( DMLProgram dmlp )
+ public void analyzeProgram()
throws HopsException, ParseException, LanguageException
{
- FunctionCallGraph fgraph = new FunctionCallGraph(dmlp);
+ //TODO move main IPA into separate IPA pass for size propagation
+ //together with rework of candidate selection
//step 1: get candidates for statistics propagation into functions (if required)
Map<String, Integer> fcandCounts = new HashMap<String, Integer>();
Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>();
Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>();
- if( !dmlp.getFunctionStatementBlocks().isEmpty() ) {
- for ( StatementBlock sb : dmlp.getStatementBlocks() ) //get candidates (over entire program)
+ if( !_prog.getFunctionStatementBlocks().isEmpty() ) {
+ for ( StatementBlock sb : _prog.getStatementBlocks() ) //get candidates (over entire program)
getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops );
pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops );
determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ );
- DMLTranslator.resetHopsDAGVisitStatus( dmlp );
+ DMLTranslator.resetHopsDAGVisitStatus( _prog );
}
//step 2: get unary dimension-preserving non-candidate functions
- Collection<String> unaryFcandTmp = fgraph.getReachableFunctions(fcandCounts.keySet());
+ Collection<String> unaryFcandTmp = _fgraph.getReachableFunctions(fcandCounts.keySet());
HashSet<String> unaryFcands = new HashSet<String>();
if( !unaryFcandTmp.isEmpty() && UNARY_DIMS_PRESERVING_FUNS ) {
for( String tmp : unaryFcandTmp )
- if( isUnarySizePreservingFunction(dmlp.getFunctionStatementBlock(tmp)) )
+ if( isUnarySizePreservingFunction(_prog.getFunctionStatementBlock(tmp)) )
unaryFcands.add(tmp);
}
@@ -180,61 +191,37 @@ public class InterProceduralAnalysis
if( !fcandCounts.isEmpty() || INTRA_PROCEDURAL_ANALYSIS ) {
//(callVars used to chain outputs/inputs of multiple functions calls)
LocalVariableMap callVars = new LocalVariableMap();
- for ( StatementBlock sb : dmlp.getStatementBlocks() ) //propagate stats into candidates
+ for ( StatementBlock sb : _prog.getStatementBlocks() ) //propagate stats into candidates
propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, unaryFcands, new HashSet<String>() );
}
- //step 4: remove unused functions (e.g., inlined or never called)
- if( REMOVE_UNUSED_FUNCTIONS ) {
- removeUnusedFunctions( dmlp, fgraph );
- }
-
- //step 5: flag functions with loops for 'recompile-on-entry'
- if( FLAG_FUNCTION_RECOMPILE_ONCE ) {
- flagFunctionsForRecompileOnce( dmlp, fgraph );
- }
-
- //step 6: set global data flow properties
- if( REMOVE_UNNECESSARY_CHECKPOINTS
- && OptimizerUtils.isSparkExecutionMode() )
- {
- //remove unnecessary checkpoint before update
- removeCheckpointBeforeUpdate(dmlp);
-
- //move necessary checkpoint after update
- moveCheckpointAfterUpdate(dmlp);
-
- //remove unnecessary checkpoint read-{write|uagg}
- removeCheckpointReadWrite(dmlp);
- }
-
- //step 7: remove constant binary ops
- if( REMOVE_CONSTANT_BINARY_OPS ) {
- removeConstantBinaryOps(dmlp);
- }
+ //step 4: apply additional IPA passes
+ for( IPAPass pass : _passes )
+ if( pass.isApplicable() )
+ pass.rewriteProgram(_prog, _fgraph);
}
- public Set<String> analyzeSubProgram( StatementBlock sb )
+ public Set<String> analyzeSubProgram()
throws HopsException, ParseException
{
- DMLTranslator.resetHopsDAGVisitStatus(sb);
+ DMLTranslator.resetHopsDAGVisitStatus(_sb);
//step 1: get candidates for statistics propagation into functions (if required)
Map<String, Integer> fcandCounts = new HashMap<String, Integer>();
Map<String, FunctionOp> fcandHops = new HashMap<String, FunctionOp>();
Map<String, Set<Long>> fcandSafeNNZ = new HashMap<String, Set<Long>>();
Set<String> allFCandKeys = new HashSet<String>();
- getFunctionCandidatesForStatisticPropagation( sb, fcandCounts, fcandHops );
+ getFunctionCandidatesForStatisticPropagation( _sb, fcandCounts, fcandHops );
allFCandKeys.addAll(fcandCounts.keySet()); //cp before pruning
pruneFunctionCandidatesForStatisticPropagation( fcandCounts, fcandHops );
determineFunctionCandidatesNNZPropagation( fcandHops, fcandSafeNNZ );
- DMLTranslator.resetHopsDAGVisitStatus( sb );
+ DMLTranslator.resetHopsDAGVisitStatus( _sb );
if( !fcandCounts.isEmpty() ) {
//step 2: propagate statistics into functions and across DAGs
//(callVars used to chain outputs/inputs of multiple functions calls)
LocalVariableMap callVars = new LocalVariableMap();
- propagateStatisticsAcrossBlock( sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() );
+ propagateStatisticsAcrossBlock( _sb, fcandCounts, callVars, fcandSafeNNZ, new HashSet<String>(), new HashSet<String>() );
}
return fcandCounts.keySet();
@@ -937,456 +924,4 @@ public class InterProceduralAnalysis
return moOut;
}
-
- /////////////////////////////
- // REMOVE UNUSED FUNCTIONS
- //////
-
- public void removeUnusedFunctions( DMLProgram dmlp, FunctionCallGraph fgraph )
- throws LanguageException
- {
- Set<String> fnamespaces = dmlp.getNamespaces().keySet();
- for( String fnspace : fnamespaces ) {
- HashMap<String, FunctionStatementBlock> fsbs = dmlp.getFunctionStatementBlocks(fnspace);
- Iterator<Entry<String, FunctionStatementBlock>> iter = fsbs.entrySet().iterator();
- while( iter.hasNext() ) {
- Entry<String, FunctionStatementBlock> e = iter.next();
- if( !fgraph.isReachableFunction(fnspace, e.getKey()) ) {
- iter.remove();
- if( LOG.isDebugEnabled() )
- LOG.debug("IPA: Removed unused function: " +
- DMLProgram.constructFunctionKey(fnspace, e.getKey()));
- }
- }
- }
- }
-
-
- /////////////////////////////
- // FLAG FUNCTIONS FOR RECOMPILE_ONCE
- //////
-
- /**
- * TODO call it after construct lops
- *
- * @param dmlp the DML program
- * @param fgraph the function call graph
- * @throws LanguageException if LanguageException occurs
- */
- public void flagFunctionsForRecompileOnce( DMLProgram dmlp, FunctionCallGraph fgraph )
- throws LanguageException
- {
- for (String namespaceKey : dmlp.getNamespaces().keySet())
- for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet())
- {
- FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);
- if( !fgraph.isRecursiveFunction(namespaceKey, fname) &&
- rFlagFunctionForRecompileOnce( fsblock, false ) )
- {
- fsblock.setRecompileOnce( true );
- if( LOG.isDebugEnabled() )
- LOG.debug("IPA: FUNC flagged for recompile-once: " +
- DMLProgram.constructFunctionKey(namespaceKey, fname));
- }
- }
- }
-
- /**
- * Returns true if this statementblock requires recompilation inside a
- * loop statement block.
- *
- * @param sb statement block
- * @param inLoop true if in loop
- * @return true if statement block requires recompilation inside a loop statement block
- */
- public boolean rFlagFunctionForRecompileOnce( StatementBlock sb, boolean inLoop )
- {
- boolean ret = false;
-
- if (sb instanceof FunctionStatementBlock)
- {
- FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
- FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
- for( StatementBlock c : fstmt.getBody() )
- ret |= rFlagFunctionForRecompileOnce( c, inLoop );
- }
- else if (sb instanceof WhileStatementBlock)
- {
- //recompilation information not available at this point
- ret = true;
-
- /*
- WhileStatementBlock wsb = (WhileStatementBlock) sb;
- WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
- ret |= (inLoop && wsb.requiresPredicateRecompilation() );
- for( StatementBlock c : wstmt.getBody() )
- ret |= rFlagFunctionForRecompileOnce( c, true );
- */
- }
- else if (sb instanceof IfStatementBlock)
- {
- IfStatementBlock isb = (IfStatementBlock) sb;
- IfStatement istmt = (IfStatement)isb.getStatement(0);
- ret |= (inLoop && isb.requiresPredicateRecompilation() );
- for( StatementBlock c : istmt.getIfBody() )
- ret |= rFlagFunctionForRecompileOnce( c, inLoop );
- for( StatementBlock c : istmt.getElseBody() )
- ret |= rFlagFunctionForRecompileOnce( c, inLoop );
- }
- else if (sb instanceof ForStatementBlock)
- {
- //recompilation information not available at this point
- ret = true;
-
- /*
- ForStatementBlock fsb = (ForStatementBlock) sb;
- ForStatement fstmt = (ForStatement)fsb.getStatement(0);
- for( StatementBlock c : fstmt.getBody() )
- ret |= rFlagFunctionForRecompileOnce( c, true );
- */
- }
- else
- {
- ret |= ( inLoop && sb.requiresRecompilation() );
- }
-
- return ret;
- }
-
- /////////////////////////////
- // REMOVE UNNECESSARY CHECKPOINTS
- //////
-
- private void removeCheckpointBeforeUpdate(DMLProgram dmlp)
- throws HopsException
- {
- //approach: scan over top-level program (guaranteed to be unconditional),
- //collect checkpoints; determine if used before update; remove first checkpoint
- //on second checkpoint if update in between and not used before update
-
- HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
-
- for( StatementBlock sb : dmlp.getStatementBlocks() )
- {
- //prune candidates (used before updated)
- Set<String> cands = new HashSet<String>(chkpointCand.keySet());
- for( String cand : cands )
- if( sb.variablesRead().containsVariable(cand)
- && !sb.variablesUpdated().containsVariable(cand) )
- {
- //note: variableRead might include false positives due to meta
- //data operations like nrow(X) or operations removed by rewrites
- //double check hops on basic blocks; otherwise worst-case
- boolean skipRemove = false;
- if( sb.get_hops() !=null ) {
- Hop.resetVisitStatus(sb.get_hops());
- skipRemove = true;
- for( Hop root : sb.get_hops() )
- skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
- }
- if( !skipRemove )
- chkpointCand.remove(cand);
- }
-
- //prune candidates (updated in conditional control flow)
- Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
- if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
- || sb instanceof ForStatementBlock )
- {
- for( String cand : cands2 )
- if( sb.variablesUpdated().containsVariable(cand) ) {
- chkpointCand.remove(cand);
- }
- }
- //prune candidates (updated w/ multiple reads)
- else
- {
- for( String cand : cands2 )
- if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null)
- {
- Hop.resetVisitStatus(sb.get_hops());
- for( Hop root : sb.get_hops() )
- if( root.getName().equals(cand) &&
- !HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
- chkpointCand.remove(cand);
- }
- }
- }
-
- //collect checkpoints and remove unnecessary checkpoints
- ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
- for( Hop chkpoint : tmp ) {
- if( chkpointCand.containsKey(chkpoint.getName()) ) {
- chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);
- }
- chkpointCand.put(chkpoint.getName(), chkpoint);
- }
-
- }
- }
-
- private void moveCheckpointAfterUpdate(DMLProgram dmlp)
- throws HopsException
- {
- //approach: scan over top-level program (guaranteed to be unconditional),
- //collect checkpoints; determine if used before update; move first checkpoint
- //after update if not used before update (best effort move which often avoids
- //the second checkpoint on loops even though used in between)
-
- HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
-
- for( StatementBlock sb : dmlp.getStatementBlocks() )
- {
- //prune candidates (used before updated)
- Set<String> cands = new HashSet<String>(chkpointCand.keySet());
- for( String cand : cands )
- if( sb.variablesRead().containsVariable(cand)
- && !sb.variablesUpdated().containsVariable(cand) )
- {
- //note: variableRead might include false positives due to meta
- //data operations like nrow(X) or operations removed by rewrites
- //double check hops on basic blocks; otherwise worst-case
- boolean skipRemove = false;
- if( sb.get_hops() !=null ) {
- Hop.resetVisitStatus(sb.get_hops());
- skipRemove = true;
- for( Hop root : sb.get_hops() )
- skipRemove &= !HopRewriteUtils.rContainsRead(root, cand, false);
- }
- if( !skipRemove )
- chkpointCand.remove(cand);
- }
-
- //prune candidates (updated in conditional control flow)
- Set<String> cands2 = new HashSet<String>(chkpointCand.keySet());
- if( sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
- || sb instanceof ForStatementBlock )
- {
- for( String cand : cands2 )
- if( sb.variablesUpdated().containsVariable(cand) ) {
- chkpointCand.remove(cand);
- }
- }
- //move checkpoint after update with simple read chain
- //(note: right now this only applies if the checkpoints comes from a previous
- //statement block, within-dag checkpoints should be handled during injection)
- else
- {
- for( String cand : cands2 )
- if( sb.variablesUpdated().containsVariable(cand) && sb.get_hops() != null) {
- Hop.resetVisitStatus(sb.get_hops());
- for( Hop root : sb.get_hops() )
- if( root.getName().equals(cand) ) {
- if( HopRewriteUtils.rHasSimpleReadChain(root, cand) ) {
- chkpointCand.get(cand).setRequiresCheckpoint(false);
- root.getInput().get(0).setRequiresCheckpoint(true);
- chkpointCand.put(cand, root.getInput().get(0));
- }
- else
- chkpointCand.remove(cand);
- }
- }
- }
-
- //collect checkpoints
- ArrayList<Hop> tmp = collectCheckpoints(sb.get_hops());
- for( Hop chkpoint : tmp ) {
- chkpointCand.put(chkpoint.getName(), chkpoint);
- }
- }
- }
-
- private void removeCheckpointReadWrite(DMLProgram dmlp)
- throws HopsException
- {
- List<StatementBlock> sbs = dmlp.getStatementBlocks();
-
- if( sbs.size()==1 & !(sbs.get(0) instanceof IfStatementBlock
- || sbs.get(0) instanceof WhileStatementBlock
- || sbs.get(0) instanceof ForStatementBlock) )
- {
- //recursively process all dag roots
- if( sbs.get(0).get_hops()!=null ) {
- Hop.resetVisitStatus(sbs.get(0).get_hops());
- for( Hop root : sbs.get(0).get_hops() )
- rRemoveCheckpointReadWrite(root);
- }
- }
- }
-
- private ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots)
- {
- ArrayList<Hop> ret = new ArrayList<Hop>();
- if( roots != null ) {
- Hop.resetVisitStatus(roots);
- for( Hop root : roots )
- rCollectCheckpoints(root, ret);
- }
-
- return ret;
- }
-
- private void rCollectCheckpoints(Hop hop, ArrayList<Hop> checkpoints)
- {
- if( hop.isVisited() )
- return;
-
- //handle leaf node for variable (checkpoint directly bound
- //to logical variable name and not used)
- if( hop.requiresCheckpoint() && hop.getParent().size()==1
- && hop.getParent().get(0) instanceof DataOp
- && ((DataOp)hop.getParent().get(0)).getDataOpType()==DataOpTypes.TRANSIENTWRITE)
- {
- checkpoints.add(hop);
- }
-
- //recursively process child nodes
- for( Hop c : hop.getInput() )
- rCollectCheckpoints(c, checkpoints);
-
- hop.setVisited();
- }
-
- public static void rRemoveCheckpointReadWrite(Hop hop)
- {
- if( hop.isVisited() )
- return;
-
- //remove checkpoint on pread if only consumed by pwrite or uagg
- if( (hop instanceof DataOp && ((DataOp)hop).getDataOpType()==DataOpTypes.PERSISTENTWRITE)
- || hop instanceof AggUnaryOp )
- {
- //(pwrite|uagg) - pread
- Hop c0 = hop.getInput().get(0);
- if( c0.requiresCheckpoint() && c0.getParent().size() == 1
- && c0 instanceof DataOp && ((DataOp)c0).getDataOpType()==DataOpTypes.PERSISTENTREAD )
- {
- c0.setRequiresCheckpoint(false);
- }
-
- //(pwrite|uagg) - frame/matri cast - pread
- if( c0 instanceof UnaryOp && c0.getParent().size() == 1
- && (((UnaryOp)c0).getOp()==OpOp1.CAST_AS_FRAME || ((UnaryOp)c0).getOp()==OpOp1.CAST_AS_MATRIX )
- && c0.getInput().get(0).requiresCheckpoint() && c0.getInput().get(0).getParent().size() == 1
- && c0.getInput().get(0) instanceof DataOp
- && ((DataOp)c0.getInput().get(0)).getDataOpType()==DataOpTypes.PERSISTENTREAD )
- {
- c0.getInput().get(0).setRequiresCheckpoint(false);
- }
- }
-
- //recursively process children
- for( Hop c : hop.getInput() )
- rRemoveCheckpointReadWrite(c);
-
- hop.setVisited();
- }
-
- /////////////////////////////
- // REMOVE CONSTANT BINARY OPS
- //////
-
- private void removeConstantBinaryOps(DMLProgram dmlp)
- throws HopsException
- {
- //approach: scan over top-level program (guaranteed to be unconditional),
- //collect ones=matrix(1,...); remove b(*)ones if not outer operation
- HashMap<String, Hop> mOnes = new HashMap<String, Hop>();
-
- for( StatementBlock sb : dmlp.getStatementBlocks() )
- {
- //pruning updated variables
- for( String var : sb.variablesUpdated().getVariableNames() )
- if( mOnes.containsKey( var ) )
- mOnes.remove( var );
-
- //replace constant binary ops
- if( !mOnes.isEmpty() )
- rRemoveConstantBinaryOp(sb, mOnes);
-
- //collect matrices of ones from last-level statement blocks
- if( !(sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock
- || sb instanceof ForStatementBlock) )
- {
- collectMatrixOfOnes(sb.get_hops(), mOnes);
- }
- }
- }
-
- private void collectMatrixOfOnes(ArrayList<Hop> roots, HashMap<String,Hop> mOnes)
- {
- if( roots == null )
- return;
-
- for( Hop root : roots )
- if( root instanceof DataOp && ((DataOp)root).getDataOpType()==DataOpTypes.TRANSIENTWRITE
- && root.getInput().get(0) instanceof DataGenOp
- && ((DataGenOp)root.getInput().get(0)).getOp()==DataGenMethod.RAND
- && ((DataGenOp)root.getInput().get(0)).hasConstantValue(1.0))
- {
- mOnes.put(root.getName(),root.getInput().get(0));
- }
- }
-
- private void rRemoveConstantBinaryOp(StatementBlock sb, HashMap<String,Hop> mOnes)
- throws HopsException
- {
- if( sb instanceof IfStatementBlock )
- {
- IfStatementBlock isb = (IfStatementBlock) sb;
- IfStatement istmt = (IfStatement)isb.getStatement(0);
- for( StatementBlock c : istmt.getIfBody() )
- rRemoveConstantBinaryOp(c, mOnes);
- if( istmt.getElseBody() != null )
- for( StatementBlock c : istmt.getElseBody() )
- rRemoveConstantBinaryOp(c, mOnes);
- }
- else if( sb instanceof WhileStatementBlock )
- {
- WhileStatementBlock wsb = (WhileStatementBlock) sb;
- WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
- for( StatementBlock c : wstmt.getBody() )
- rRemoveConstantBinaryOp(c, mOnes);
- }
- else if( sb instanceof ForStatementBlock )
- {
- ForStatementBlock fsb = (ForStatementBlock) sb;
- ForStatement fstmt = (ForStatement)fsb.getStatement(0);
- for( StatementBlock c : fstmt.getBody() )
- rRemoveConstantBinaryOp(c, mOnes);
- }
- else
- {
- if( sb.get_hops() != null ){
- Hop.resetVisitStatus(sb.get_hops());
- for( Hop hop : sb.get_hops() )
- rRemoveConstantBinaryOp(hop, mOnes);
- }
- }
- }
-
- private void rRemoveConstantBinaryOp(Hop hop, HashMap<String,Hop> mOnes)
- {
- if( hop.isVisited() )
- return;
-
- if( hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT
- && !((BinaryOp) hop).isOuterVectorOperator()
- && hop.getInput().get(0).getDataType()==DataType.MATRIX
- && hop.getInput().get(1) instanceof DataOp
- && mOnes.containsKey(hop.getInput().get(1).getName()) )
- {
- //replace matrix of ones with literal 1 (later on removed by
- //algebraic simplification rewrites; otherwise more complex
- //recursive processing of childs and rewiring required)
- HopRewriteUtils.removeChildReferenceByPos(hop, hop.getInput().get(1), 1);
- HopRewriteUtils.addChildReference(hop, new LiteralOp(1), 1);
- }
-
- //recursively process child nodes
- for( Hop c : hop.getInput() )
- rRemoveConstantBinaryOp(c, mOnes);
-
- hop.setVisited();
- }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 2b4bfa0..47446f6 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -265,8 +265,8 @@ public class DMLTranslator
//propagate size information from main into functions (but conservatively)
if( OptimizerUtils.ALLOW_INTER_PROCEDURAL_ANALYSIS ) {
- InterProceduralAnalysis ipa = new InterProceduralAnalysis();
- ipa.analyzeProgram(dmlp);
+ InterProceduralAnalysis ipa = new InterProceduralAnalysis(dmlp);
+ ipa.analyzeProgram();
resetHopsDAGVisitStatus(dmlp);
if (OptimizerUtils.ALLOW_IPA_SECOND_CHANCE) {
// SECOND CHANCE:
@@ -275,7 +275,7 @@ public class DMLTranslator
// and then further scalar -> literal replacement (IPA).
rewriter.rewriteProgramHopDAGs(dmlp);
resetHopsDAGVisitStatus(dmlp);
- ipa.analyzeProgram(dmlp);
+ ipa.analyzeProgram();
resetHopsDAGVisitStatus(dmlp);
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/e9fb7a02/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
index 2469005..92dd567 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizationWrapper.java
@@ -204,8 +204,8 @@ public class OptimizationWrapper
//inter-procedural optimization (based on previous recompilation)
if( pb.hasFunctions() ) {
- InterProceduralAnalysis ipa = new InterProceduralAnalysis();
- Set<String> fcand = ipa.analyzeSubProgram(sb);
+ InterProceduralAnalysis ipa = new InterProceduralAnalysis(sb);
+ Set<String> fcand = ipa.analyzeSubProgram();
if( !fcand.isEmpty() ) {
//regenerate runtime program of modified functions