You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/06/19 19:38:02 UTC

[2/2] systemml git commit: [SYSTEMML-1719] New codegen common-subexpression elimination for cplans

[SYSTEMML-1719] New codegen common-subexpression elimination for cplans

This patch introduces a general-purpose CSE rewrite for code generation
plans and applies it during cleanup of cplans before hop dag
modification. The advantages are better generated code (without
unnecessary lookups as often encountered in multi-agg templates) and
better plan cache hit rates.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/9389a5e1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/9389a5e1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/9389a5e1

Branch: refs/heads/master
Commit: 9389a5e1e0bd081ef0037321a7d1e7eac328cfbe
Parents: 0afaa24
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Jun 18 23:59:22 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon Jun 19 12:38:21 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  49 ++++-----
 .../apache/sysml/hops/codegen/cplan/CNode.java  |   4 +
 .../sysml/hops/codegen/cplan/CNodeData.java     |  10 +-
 .../sysml/hops/codegen/cplan/CNodeTpl.java      |  16 +--
 .../hops/codegen/template/CPlanCSERewriter.java | 100 +++++++++++++++++++
 .../hops/codegen/template/TemplateCell.java     |   3 +-
 6 files changed, 140 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 988af7c..56df6fc 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -46,6 +46,7 @@ import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.hops.codegen.template.TemplateBase;
 import org.apache.sysml.hops.codegen.template.TemplateBase.CloseType;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.codegen.template.CPlanCSERewriter;
 import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
 import org.apache.sysml.hops.codegen.template.PlanSelection;
 import org.apache.sysml.hops.codegen.template.PlanSelectionFuseCostBased;
@@ -347,7 +348,8 @@ public class SpoofCompiler
 			HashMap<Long, Pair<Hop[],CNodeTpl>>  cplans = constructCPlans(roots, compileLiterals);
 			
 			//cleanup codegen plans (remove unnecessary inputs, fix hop-cnodedata mapping,
-			//remove empty templates with single cnodedata input, remove spurious lookups)
+			//remove empty templates with single cnodedata input, remove spurious lookups,
+			//perform common subexpression elimination)
 			cplans = cleanupCPlans(cplans);
 			
 			//explain before modification
@@ -663,33 +665,26 @@ public class SpoofCompiler
 	 * 
 	 * @param cplans set of cplans
 	 */
-	private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) {
+	private static HashMap<Long, Pair<Hop[],CNodeTpl>> cleanupCPlans(HashMap<Long, Pair<Hop[],CNodeTpl>> cplans) 
+	{
 		HashMap<Long, Pair<Hop[],CNodeTpl>> cplans2 = new HashMap<Long, Pair<Hop[],CNodeTpl>>();
+		CPlanCSERewriter cse = new CPlanCSERewriter();
+		
 		for( Entry<Long, Pair<Hop[],CNodeTpl>> e : cplans.entrySet() ) {
 			CNodeTpl tpl = e.getValue().getValue();
 			Hop[] inHops = e.getValue().getKey();
 			
-			//collect cplan leaf node names
-			HashSet<Long> leafs = new HashSet<Long>();
-			if( tpl instanceof CNodeMultiAgg )
-				for( CNode out : ((CNodeMultiAgg)tpl).getOutputs() )
-					rCollectLeafIDs(out, leafs);
-			else
-				rCollectLeafIDs(tpl.getOutput(), leafs);
+			//perform common subexpression elimination
+			tpl = cse.eliminateCommonSubexpressions(tpl);
 			
-			//create clean cplan w/ minimal inputs
-			if( inHops.length == leafs.size() )
-				cplans2.put(e.getKey(), e.getValue());
-			else {
-				tpl.cleanupInputs(leafs);
-				ArrayList<Hop> tmp = new ArrayList<Hop>();
-				for( Hop hop : inHops ) {
-					if( hop!= null && leafs.contains(hop.getHopID()) )
-						tmp.add(hop);
-				}
-				cplans2.put(e.getKey(), new Pair<Hop[],CNodeTpl>(
-						tmp.toArray(new Hop[0]),tpl));
-			}
+			//update input hops (order-preserving)
+			HashSet<Long> inputHopIDs = tpl.getInputHopIDs(false);
+			ArrayList<Hop> tmp = new ArrayList<Hop>();
+			for( Hop input : inHops )
+				if( inputHopIDs.contains(input.getHopID()) )
+					tmp.add(input);
+			inHops = tmp.toArray(new Hop[0]);
+			cplans2.put(e.getKey(), new Pair<Hop[],CNodeTpl>(inHops, tpl));
 			
 			//remove invalid plans with column indexing on main input
 			if( tpl instanceof CNodeCell ) {
@@ -734,16 +729,6 @@ public class SpoofCompiler
 		return cplans2;
 	}
 	
-	private static void rCollectLeafIDs(CNode node, HashSet<Long> leafs) {
-		//collect leaf variable names
-		if( node instanceof CNodeData && !((CNodeData)node).isLiteral() )
-			leafs.add(((CNodeData) node).getHopID());
-		
-		//recursively process cplan
-		for( CNode c : node.getInput() )
-			rCollectLeafIDs(c, leafs);
-	}
-	
 	private static void rFindAndRemoveLookupMultiAgg(CNodeMultiAgg node, CNodeData mainInput) {
 		//process all outputs individually
 		for( CNode output : node.getOutputs() )

http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
index df40d58..1a0617a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNode.java
@@ -83,6 +83,10 @@ public abstract class CNode
 		_generated = false;
 	}
 	
+	public void resetHash() {
+		_hash = 0; //NOTE(review): 0 presumably marks the cached hash as not yet computed - confirm in hashCode()
+	}
+	
 	public void setNumRows(long rows) {
 		_rows = rows;
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
index 88baa61..893554c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeData.java
@@ -28,6 +28,7 @@ public class CNodeData extends CNode
 {
 	protected final String _name;
 	protected final long _hopID;
+	private boolean _strictEquals;
 	
 	public CNodeData(Hop hop) {
 		this(hop, hop.getDim1(), hop.getDim2(), hop.getDataType());
@@ -36,6 +37,7 @@ public class CNodeData extends CNode
 	public CNodeData(Hop hop, long rows, long cols, DataType dt) {
 		//note: previous rewrites might have created hops with equal name
 		//hence, we also keep the hopID to uniquely identify inputs
+		super();
 		_name = hop.getName();
 		_hopID = hop.getHopID();
 		_rows = rows;
@@ -67,6 +69,11 @@ public class CNodeData extends CNode
 		return _hopID;
 	}
 	
+	public void setStrictEquals(boolean flag) {
+		_strictEquals = flag; //if true, equals() additionally compares the hop IDs of non-literal data nodes
+		_hash = 0; //NOTE(review): presumably invalidates the cached hash after the mode change - confirm in hashCode()
+	}
+	
 	@Override
 	public String codegen(boolean sparse) {
 		return "";
@@ -97,6 +104,7 @@ public class CNodeData extends CNode
 		return (o instanceof CNodeData 
 			&& super.equals(o)
 			&& isLiteral() == ((CNodeData)o).isLiteral()
-			&& (isLiteral() ? _name.equals(((CNodeData)o)._name) : true));
+			&& (isLiteral() ? _name.equals(((CNodeData)o)._name) : 
+			_strictEquals ? _hopID == ((CNodeData)o)._hopID : true));
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
index 81351e6..a4dae72 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
@@ -51,14 +51,6 @@ public abstract class CNodeTpl extends CNode implements Cloneable
 		_inputs.add(in);
 	}
 	
-	public void cleanupInputs(HashSet<Long> filter) {
-		ArrayList<CNode> tmp = new ArrayList<CNode>();
-		for( CNode in : _inputs )
-			if( in instanceof CNodeData && filter.contains(((CNodeData) in).getHopID()) )
-				tmp.add(in);
-		_inputs = tmp;
-	}
-	
 	public String[] getInputNames() {
 		String[] ret = new String[_inputs.size()];
 		for( int i=0; i<_inputs.size(); i++ )
@@ -66,6 +58,14 @@ public abstract class CNodeTpl extends CNode implements Cloneable
 		return ret;
 	}
 	
+	public HashSet<Long> getInputHopIDs(boolean inclLiterals) {
+		HashSet<Long> ret = new HashSet<Long>();
+		for( CNode input : _inputs )
+			if( !input.isLiteral() || inclLiterals ) //literal inputs are only collected on request
+				ret.add(((CNodeData)input).getHopID()); //NOTE(review): unchecked cast, assumes all template inputs are CNodeData
+		return ret;
+	}
+	
 	public void resetVisitStatusOutputs() {
 		getOutput().resetVisitStatus();
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java
new file mode 100644
index 0000000..95f0ed7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanCSERewriter.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.template;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.codegen.cplan.CNode;
+import org.apache.sysml.hops.codegen.cplan.CNodeData;
+import org.apache.sysml.hops.codegen.cplan.CNodeMultiAgg;
+import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
+
+public class CPlanCSERewriter 
+{
+	public CNodeTpl eliminateCommonSubexpressions(CNodeTpl tpl) 
+	{
+		//Note: Compared to our traditional common subexpression elimination, on cplans, 
+		//we don't have any parent references, and hence cannot use a collect-merge approach. 
+		//In contrast, we exploit the hash signatures of cnodes as used in the plan cache. 
+		//However, note that these signatures ignore input hops by default (for better plan 
+		//cache hit rates), but are temporarily set to strict evaluation for this rewrite. 
+		
+		List<CNode> outputs = (tpl instanceof CNodeMultiAgg) ? 
+			((CNodeMultiAgg)tpl).getOutputs() : 
+			Arrays.asList(tpl.getOutput());
+		
+		//step 1: set data nodes to strict comparison (hop IDs become part of equals)
+		HashSet<Long> memo = new HashSet<Long>();
+		for( CNode out : outputs )
+			rSetStrictDataNodeComparision(out, memo, true);
+		
+		//step 2: perform common subexpression elimination
+		HashMap<CNode,CNode> cseSet = new HashMap<CNode,CNode>();
+		memo.clear();
+		for( CNode out : outputs )
+			rEliminateCommonSubexpression(out, cseSet, memo);
+		
+		//step 3: reset data nodes to imprecise comparison (flag=false undoes step 1)
+		memo.clear();
+		for( CNode out : outputs )
+			rSetStrictDataNodeComparision(out, memo, false);
+		
+		return tpl;
+	}
+	
+	private void rEliminateCommonSubexpression(CNode current, HashMap<CNode,CNode> cseSet, HashSet<Long> memo) {
+		//avoid redundant re-evaluation of nodes reachable via multiple paths
+		if( memo.contains(current.getID()) )
+			return;
+		
+		//replace inputs with equal, already registered nodes (values are never null)
+		for( int i=0; i<current.getInput().size(); i++ ) {
+			CNode rep = cseSet.get(current.getInput().get(i));
+			if( rep != null )
+				current.getInput().set(i, rep);
+		}
+		
+		//process inputs recursively
+		for( CNode input : current.getInput() )
+			rEliminateCommonSubexpression(input, cseSet, memo);
+		
+		//register node itself after its inputs, for lookups by subsequently visited nodes
+		cseSet.put(current, current);
+		memo.add(current.getID());
+	}
+	
+	private void rSetStrictDataNodeComparision(CNode current, HashSet<Long> memo, boolean flag) {
+		//avoid redundant re-evaluation of nodes reachable via multiple paths
+		if( memo.contains(current.getID()) )
+			return;
+		
+		//process inputs recursively (incl reset of input hashes) and node itself
+		for( CNode input : current.getInput() ) {
+			rSetStrictDataNodeComparision(input, memo, flag);
+			input.resetHash();
+		}
+		if( current instanceof CNodeData )
+			((CNodeData)current).setStrictEquals(flag);
+		memo.add(current.getID());
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/9389a5e1/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 5455775..26d477d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -163,7 +163,8 @@ public class TemplateCell extends TemplateBase
 			if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp)
 				&& (me.type!=TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl)))
 				rConstructCplan(c, memo, tmp, inHops, compileLiterals);
-			else if( me!=null && me.type==TemplateType.MultiAggTpl && HopRewriteUtils.isMatrixMultiply(hop) && i==0 )
+			else if( me!=null && (me.type==TemplateType.MultiAggTpl || me.type==TemplateType.CellTpl) 
+					&& HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose
 				rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
 			else {
 				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);