You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/04/07 04:15:16 UTC

[3/3] incubator-systemml git commit: [SYSTEMML-1288] Extended code generator (multi-agg across partitions)

[SYSTEMML-1288] Extended code generator (multi-agg across partitions)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9820f4c5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9820f4c5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9820f4c5

Branch: refs/heads/master
Commit: 9820f4c5293c69873f68544748507b6473948f12
Parents: 7c15339
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu Apr 6 20:57:07 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Apr 6 21:16:31 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeMultiAgg.java |  5 +-
 .../template/PlanSelectionFuseCostBased.java    | 54 +++++++++++++++++++-
 .../hops/codegen/template/TemplateCell.java     |  3 +-
 3 files changed, 59 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
index 7ec07a6..d9502be 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeMultiAgg.java
@@ -106,7 +106,10 @@ public class CNodeMultiAgg extends CNodeTpl
 		for( int i=0; i<_outputs.size(); i++ ) {
 			CNode out = _outputs.get(i);
 			String tmpOut = getAggTemplate(i);
-			tmpOut = tmpOut.replace("%IN%", out.getVarname());
+			//get variable name (w/ handling of direct consumption of inputs)
+			String varName = (out instanceof CNodeData && ((CNodeData)out).getHopID()==
+				((CNodeData)_inputs.get(0)).getHopID()) ? "a" : out.getVarname(); 
+			tmpOut = tmpOut.replace("%IN%", varName);
 			tmpOut = tmpOut.replace("%IX%", String.valueOf(i));
 			sb.append(tmpOut);
 		}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 50d6ff1..151dab2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -34,7 +34,9 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
@@ -82,12 +84,15 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 			if( LOG.isTraceEnabled() )
 				LOG.trace("Partition materialization points: "+Arrays.toString(M.toArray(new Long[0])));
 			
-			//step 3: create composite templates entries
+			//step 3: create composite templates (within the partition)
 			createAndAddMultiAggPlans(memo, partition, R);
 			
 			//step 4: plan enumeration and plan selection
 			selectPlans(memo, partition, R, M);
 		}
+		
+		//step 5: add composite templates (across partitions)
+		createAndAddMultiAggPlans(memo, roots);
 	
 		//take all distinct best plans
 		for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
@@ -217,6 +222,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 			&& partition.contains(hop.getHopID());
 	}
 	
+	//within-partition multi-agg templates
 	private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R)
 	{
 		//create index of plans that reference full aggregates to avoid circular dependencies
@@ -262,6 +268,30 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		}
 	}
 	
+	//across-partition multi-agg templates: fuse full aggregates that read the same input
+	private static void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
+	{
+		//#1: collect full aggregations over shared inputs (otherwise never fused)
+		HashMap<Long, ArrayList<Long>> fullAggs = new HashMap<Long, ArrayList<Long>>();
+		Hop.resetVisitStatus(roots);
+		for( Hop hop : roots )
+			rCollectAggregatesSharedRead(hop, fullAggs);
+		
+		//#2: construct and add multiagg template plans (w/ max 3 aggregations per plan)
+		for( Entry<Long, ArrayList<Long>> e : fullAggs.entrySet() ) {
+			ArrayList<Long> aggs = e.getValue();
+			if( aggs.size()<=1 ) //a single aggregate has nothing to be fused with
+				continue;
+			MemoTableEntry me = new MemoTableEntry(TemplateType.MultiAggTpl,
+				aggs.get(0), aggs.get(1), (aggs.size()>2)?aggs.get(2):-1);
+			for( int i=0; i<Math.min(aggs.size(), 3); i++ ) { //only aggs referenced by me
+				memo.add(memo._hopRefs.get(aggs.get(i)), me);
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Added multiagg* plan: "+aggs.get(i)+" "+me);
+			}
+		}
+	}
+	
 	private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
 		//ensure that aggregates are independent of each other, i.e.,
 		//they to not have potentially transitive parent child references
@@ -285,6 +315,28 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		return ret;
 	}
 	
+	/**
+	 * Recursively collects the hop IDs of all full aggregates (sum, sum_sq, min, max)
+	 * that directly read a data op, grouped by the hop ID of that data op.
+	 */
+	private static void rCollectAggregatesSharedRead(Hop current, HashMap<Long, ArrayList<Long>> aggs) {
+		if( current.isVisited() )
+			return;
+		
+		//register applicable full aggregations under the read they consume
+		boolean fullAgg = HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
+			&& ((AggUnaryOp)current).getDirection()==Direction.RowCol;
+		if( fullAgg && current.getInput().get(0) instanceof DataOp ) {
+			long readID = current.getInput().get(0).getHopID();
+			aggs.computeIfAbsent(readID, k -> new ArrayList<Long>()).add(current.getHopID());
+		}
+		
+		//descend into all inputs before marking this hop as visited
+		for( Hop child : current.getInput() )
+			rCollectAggregatesSharedRead(child, aggs);
+		current.setVisited();
+	}
+	
 	private void selectPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R, ArrayList<Long> M) 
 	{
 		//if no materialization points, use basic fuse-all w/ partition awareness

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9820f4c5/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index e3c12d5..885d3db 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -29,6 +29,7 @@ import java.util.stream.Collectors;
 import org.apache.sysml.hops.AggBinaryOp;
 import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
@@ -149,7 +150,7 @@ public class TemplateCell extends TemplateBase
 		MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.CellTpl);
 		for( int i=0; i<hop.getInput().size(); i++ ) {
 			Hop c = hop.getInput().get(i);
-			if( me.isPlanRef(i) )
+			if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp) )
 				rConstructCplan(c, memo, tmp, inHops, compileLiterals);
 			else {
 				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);