You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/24 00:49:32 UTC

[2/2] systemml git commit: [SYSTEMML-1801] Fix incomplete codegen candidate exploration

[SYSTEMML-1801] Fix incomplete codegen candidate exploration

This patch fixes various issues of the codegen candidate exploration
step. As it turned out, we missed multiple candidate plans which led to
suboptimal fused operators. Furthermore, this patch also improves the
printing of memo table entries to simplify debugging.

Furthermore, this also includes various fixes for code generation of
unary CNode operations (e.g., input var handling).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ee2b37e4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ee2b37e4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ee2b37e4

Branch: refs/heads/master
Commit: ee2b37e4fa02b48867e08e0f9401a026cf54b35c
Parents: dd6f46a
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Jul 23 15:26:55 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Jul 23 17:05:27 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/jmlc/Connection.java   | 13 +--
 .../sysml/hops/codegen/SpoofCompiler.java       | 24 ++---
 .../sysml/hops/codegen/cplan/CNodeUnary.java    | 21 +++--
 .../hops/codegen/template/CPlanMemoTable.java   | 92 +++++++++++++-------
 .../template/PlanSelectionFuseCostBased.java    |  4 +-
 .../sysml/runtime/util/UtilFunctions.java       | 11 +++
 6 files changed, 105 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/api/jmlc/Connection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
index 1993ed4..f933396 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
@@ -27,11 +27,8 @@ import java.io.InputStream;
 import java.io.InputStreamReader;
 import java.util.Arrays;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.Map;
-import java.util.Set;
 
-import org.apache.commons.collections.CollectionUtils;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.sysml.api.DMLException;
@@ -66,6 +63,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.transform.TfUtils;
 import org.apache.sysml.runtime.transform.meta.TfMetaUtils;
 import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.wink.json4j.JSONObject;
 
 /**
@@ -190,7 +188,7 @@ public class Connection implements Closeable
 			throw new LanguageException("Invalid argument names: "+Arrays.toString(invalidArgs));
 		
 		//check for valid names of input and output variables
-		String[] invalidVars = asSet(inputs, outputs).stream()
+		String[] invalidVars = UtilFunctions.asSet(inputs, outputs).stream()
 			.filter(k -> k==null || k.startsWith("$")).toArray(String[]::new);
 		if( invalidVars.length > 0 )
 			throw new LanguageException("Invalid variable names: "+Arrays.toString(invalidVars));
@@ -846,11 +844,4 @@ public class Connection implements Closeable
 	public FrameBlock readTransformMetaDataFromPath(String spec, String metapath, String colDelim) throws IOException {
 		return TfMetaUtils.readTransformMetaDataFromPath(spec, metapath, colDelim);
 	}
-	
-	private Set<String> asSet(String[] inputs, String[] outputs) {
-		Set<String> ret = new HashSet<String>();
-		CollectionUtils.addAll(ret, inputs);
-		CollectionUtils.addAll(ret, outputs);
-		return ret;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 4a59d1b..8ab2240 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -124,7 +124,11 @@ public class SpoofCompiler
 	public enum PlanSelector {
 		FUSE_ALL,             //maximal fusion, possible w/ redundant compute
 		FUSE_NO_REDUNDANCY,   //fusion without redundant compute 
-		FUSE_COST_BASED,      //cost-based decision on materialization points
+		FUSE_COST_BASED;      //cost-based decision on materialization points
+		public boolean isHeuristic() {
+			return this == FUSE_ALL
+				|| this == FUSE_NO_REDUNDANCY;
+		}
 	}
 
 	public enum PlanCachePolicy {
@@ -503,7 +507,7 @@ public class SpoofCompiler
 		//open initial operator plans, if possible
 		for( TemplateBase tpl : TemplateUtils.TEMPLATES )
 			if( tpl.open(hop) ) {
-				MemoTableEntrySet P = new MemoTableEntrySet(tpl.getType(), false);
+				MemoTableEntrySet P = new MemoTableEntrySet(hop, tpl.getType(), false);
 				memo.addAll(hop, enumPlans(hop, -1, P, tpl, memo));
 			}
 		
@@ -514,16 +518,12 @@ public class SpoofCompiler
 					TemplateBase tpl = TemplateUtils.createTemplate(me.type, me.closed);
 					if( tpl.fuse(hop, c) ) {
 						int pos = hop.getInput().indexOf(c);
-						MemoTableEntrySet P = new MemoTableEntrySet(tpl.getType(), pos, c.getHopID(), tpl.isClosed());
+						MemoTableEntrySet P = new MemoTableEntrySet(hop, tpl.getType(), pos, c.getHopID(), tpl.isClosed());
 						memo.addAll(hop, enumPlans(hop, pos, P, tpl, memo));
 					}
 				}	
 		}
 		
-		//prune subsumed / redundant plans
-		if( PRUNE_REDUNDANT_PLANS )
-			memo.pruneRedundant(hop.getHopID());
-		
 		//close operator plans, if required
 		if( memo.contains(hop.getHopID()) ) {
 			Iterator<MemoTableEntry> iter = memo.get(hop.getHopID()).iterator();
@@ -538,6 +538,10 @@ public class SpoofCompiler
 			}
 		}
 		
+		//prune subsumed / redundant plans
+		if( PRUNE_REDUNDANT_PLANS )
+			memo.pruneRedundant(hop.getHopID());
+		
 		//mark visited even if no plans found (e.g., unsupported ops)
 		memo.addHop(hop);
 	}
@@ -546,10 +550,8 @@ public class SpoofCompiler
 		for(int k=0; k<hop.getInput().size(); k++)
 			if( k != pos ) {
 				Hop input2 = hop.getInput().get(k);
-				if( memo.contains(input2.getHopID()) && !memo.get(input2.getHopID()).get(0).closed
-					&& TemplateUtils.isType(memo.get(input2.getHopID()).get(0).type, tpl.getType(), TemplateType.CellTpl)
-					&& tpl.merge(hop, input2) && (tpl.getType()!=TemplateType.RowTpl || pos==-1 
-						|| TemplateUtils.hasCommonRowTemplateMatrixInput(hop.getInput().get(pos), input2, memo)))
+				if( memo.contains(input2.getHopID(), true, tpl.getType(), TemplateType.CellTpl) 
+					&& tpl.merge(hop, input2) )
 					P.crossProduct(k, -1L, input2.getHopID());
 				else
 					P.crossProduct(k, -1L);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 1a36604..02f00b8 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -20,6 +20,7 @@
 package org.apache.sysml.hops.codegen.cplan;
 
 import org.apache.commons.lang.StringUtils;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -143,6 +144,12 @@ public class CNodeUnary extends CNode
 			String [] tmp = this.name().split("_");
 			return StringUtils.capitalize(tmp[1].toLowerCase());
 		}
+		public boolean isScalarLookup() {
+			return this == LOOKUP0 
+				|| this == UnaryType.LOOKUP_R
+				|| this == UnaryType.LOOKUP_C
+				|| this == UnaryType.LOOKUP_RC;
+		}
 	}
 	
 	private UnaryType _type;
@@ -172,7 +179,9 @@ public class CNodeUnary extends CNode
 		sb.append(_inputs.get(0).codegen(sparse));
 		
 		//generate unary operation
-		boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData);
+		boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData
+			&& !_inputs.get(0).getVarname().startsWith("b")
+			&& !_inputs.get(0).isLiteral());
 		String var = createVarname();
 		String tmp = _type.getTemplate(lsparse);
 		tmp = tmp.replace("%TMP%", var);
@@ -182,12 +191,14 @@ public class CNodeUnary extends CNode
 		//replace sparse and dense inputs
 		tmp = tmp.replace("%IN1v%", varj+"vals");
 		tmp = tmp.replace("%IN1i%", varj+"ix");
-		tmp = tmp.replace("%IN1%", varj );
+		tmp = tmp.replace("%IN1%", varj.startsWith("b") && !_type.isScalarLookup()
+			&& TemplateUtils.isMatrix(_inputs.get(0)) ? varj + ".ddat" : varj );
 		
 		//replace start position of main input
-		String spos = (!varj.startsWith("b") 
-			&& _inputs.get(0) instanceof CNodeData 
-			&& _inputs.get(0).getDataType().isMatrix()) ? varj+"i" : "0";
+		String spos = (_inputs.get(0) instanceof CNodeData 
+			&& _inputs.get(0).getDataType().isMatrix()) ? !varj.startsWith("b") ? 
+			varj+"i" : TemplateUtils.isMatrix(_inputs.get(0)) ? "rowIndex*%LEN%" : "0" : "0";
+		
 		tmp = tmp.replace("%POS1%", spos);
 		tmp = tmp.replace("%POS2%", spos);
 		

http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 6982470..074d29c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -34,8 +34,10 @@ import java.util.stream.Collectors;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.codegen.SpoofCompiler;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class CPlanMemoTable 
 {
@@ -64,8 +66,16 @@ public class CPlanMemoTable
 	}
 	
 	public boolean contains(long hopID, TemplateType type) {
-		return contains(hopID) && get(hopID)
-			.stream().anyMatch(p -> p.type==type);
+		return contains(hopID) && get(hopID).stream()
+			.anyMatch(p -> p.type==type);
+	}
+	
+	public boolean contains(long hopID, boolean checkClose, TemplateType... type) {
+		if( !checkClose && type.length==1 )
+			return contains(hopID, type[0]);
+		Set<TemplateType> probe = UtilFunctions.asSet(type);
+		return contains(hopID) && get(hopID).stream()
+			.anyMatch(p -> (!checkClose||!p.closed) && probe.contains(p.type));
 	}
 	
 	public int countEntries(long hopID) {
@@ -95,7 +105,8 @@ public class CPlanMemoTable
 	}
 	
 	public void add(Hop hop, TemplateType type, long in1, long in2, long in3) {
-		add(hop, new MemoTableEntry(type, in1, in2, in3));
+		int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size();
+		add(hop, new MemoTableEntry(type, in1, in2, in3, size));
 	}
 	
 	public void add(Hop hop, MemoTableEntry me) {
@@ -129,25 +140,31 @@ public class CPlanMemoTable
 		//prune redundant plans (i.e., equivalent) 
 		setDistinct(hopID, _plans.get(hopID));
 		
-		//prune dominated plans (e.g., opened plan subsumed
-		//by fused plan if single consumer of input)
-		HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>();
-		List<MemoTableEntry> list = _plans.get(hopID);
-		Hop hop = _hopRefs.get(hopID);
-		for( MemoTableEntry e1 : list )
-			for( MemoTableEntry e2 : list )
-				if( e1 != e2 && e1.subsumes(e2) ) {
-					//check that childs don't have multiple consumers
-					boolean rmSafe = true; 
-					for( int i=0; i<=2; i++ )
-						rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
-							hop.getInput().get(i).getParent().size()==1 : true;
-					if( rmSafe )
-						rmList.add(e2);
-				}
+		//prune closed templates without group references
+		_plans.get(hopID).removeIf(p -> p.closed && !p.hasPlanRef());
 		
-		//update current entry list, by removing rmList
-		remove(hop, rmList);
+		//prune dominated plans (e.g., opened plan subsumed by fused plan 
+		//if single consumer of input; however this only applies to fusion
+		//heuristic that only consider materialization points)
+		if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+			HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>();
+			List<MemoTableEntry> list = _plans.get(hopID);
+			Hop hop = _hopRefs.get(hopID);
+			for( MemoTableEntry e1 : list )
+				for( MemoTableEntry e2 : list )
+					if( e1 != e2 && e1.subsumes(e2) ) {
+						//check that childs don't have multiple consumers
+						boolean rmSafe = true; 
+						for( int i=0; i<=2; i++ )
+							rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
+								hop.getInput().get(i).getParent().size()==1 : true;
+						if( rmSafe )
+							rmList.add(e2);
+					}
+			
+			//update current entry list, by removing rmList
+			remove(hop, rmList);
+		}
 	}
 
 	public void pruneSuboptimal(ArrayList<Hop> roots) {
@@ -204,7 +221,7 @@ public class CPlanMemoTable
 	public List<MemoTableEntry> getDistinct(long hopID) {
 		//return distinct entries wrt type and closed attributes
 		return _plans.get(hopID).stream()
-			.map(p -> new MemoTableEntry(p.type,-1,-1,-1,p.closed))
+			.map(p -> new MemoTableEntry(p.type,-1,-1,-1,p.size,p.closed))
 			.distinct().collect(Collectors.toList());
 	}
 	
@@ -271,15 +288,17 @@ public class CPlanMemoTable
 		public final long input1; 
 		public final long input2;
 		public final long input3;
+		public final int size;
 		public boolean closed = false;
-		public MemoTableEntry(TemplateType t, long in1, long in2, long in3) {
-			this(t, in1, in2, in3, false);
+		public MemoTableEntry(TemplateType t, long in1, long in2, long in3, int inlen) {
+			this(t, in1, in2, in3, inlen, false);
 		}
-		public MemoTableEntry(TemplateType t, long in1, long in2, long in3, boolean close) {
+		public MemoTableEntry(TemplateType t, long in1, long in2, long in3, int inlen, boolean close) {
 			type = t;
 			input1 = in1;
 			input2 = in2;
 			input3 = in3;
+			size = inlen;
 			closed = close;
 		}
 		public boolean isPlanRef(int index) {
@@ -324,7 +343,16 @@ public class CPlanMemoTable
 		}
 		@Override
 		public String toString() {
-			return type.name()+"("+input1+","+input2+","+input3+")";
+			StringBuilder sb = new StringBuilder();
+			sb.append(type.name());
+			sb.append("(");
+			for( int i=0; i<size; i++ ) {
+				if( i > 0 )
+					sb.append(",");
+				sb.append(input(i));
+			}
+			sb.append(")");
+			return sb.toString();
 		}
 	}
 	
@@ -332,13 +360,15 @@ public class CPlanMemoTable
 	{
 		public ArrayList<MemoTableEntry> plans = new ArrayList<MemoTableEntry>();
 		
-		public MemoTableEntrySet(TemplateType type, boolean close) {
-			plans.add(new MemoTableEntry(type, -1, -1, -1, close));
+		public MemoTableEntrySet(Hop hop, TemplateType type, boolean close) {
+			int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size();
+			plans.add(new MemoTableEntry(type, -1, -1, -1, size, close));
 		}
 		
-		public MemoTableEntrySet(TemplateType type, int pos, long hopID, boolean close) {
+		public MemoTableEntrySet(Hop hop, TemplateType type, int pos, long hopID, boolean close) {
+			int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size();
 			plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, 
-					(pos==1)?hopID:-1, (pos==2)?hopID:-1));
+					(pos==1)?hopID:-1, (pos==2)?hopID:-1, size));
 		}
 		
 		public void crossProduct(int pos, Long... refs) {
@@ -346,7 +376,7 @@ public class CPlanMemoTable
 			for( MemoTableEntry me : plans )
 				for( Long ref : refs )
 					tmp.add(new MemoTableEntry(me.type, (pos==0)?ref:me.input1, 
-						(pos==1)?ref:me.input2, (pos==2)?ref:me.input3));
+						(pos==1)?ref:me.input2, (pos==2)?ref:me.input3, me.size));
 			plans = tmp;
 		}
 		

http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 5cc18ea..82fedff 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -252,7 +252,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 			int ito = Math.min(i+3, fullAggs.size());
 			if( ito-i >= 2 ) {
 				MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl,
-					fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1);
+					fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i);
 				if( isValidMultiAggregate(memo, me) ) {
 					for( int j=i; j<ito; j++ ) {
 						memo.add(memo._hopRefs.get(fullAggs.get(j)), me);
@@ -352,7 +352,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				continue;
 			Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
 			MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl,
-				aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1);
+				aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length);
 			for( int i=0; i<aggs.length; i++ ) {
 				memo.add(memo._hopRefs.get(aggs[i]), me);
 				addBestPlan(aggs[i], me);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ee2b37e4/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
index 8c4cacd..cec0fb0 100644
--- a/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysml/runtime/util/UtilFunctions.java
@@ -21,7 +21,9 @@ package org.apache.sysml.runtime.util;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.sysml.parser.Expression.ValueType;
@@ -583,4 +585,13 @@ public class UtilFunctions
 				return true;
 		return false;
 	}
+	
+	@SafeVarargs
+	public static <T> Set<T> asSet(T[]... inputs) {
+		Set<T> ret = new HashSet<>();
+		for( T[] input : inputs )
+			for( T element : input )
+				ret.add(element);
+		return ret;
+	}
 }