You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/09/25 05:44:11 UTC

systemml git commit: [SYSTEMML-1934] Fix codegen optimizer (incorrect structural pruning)

Repository: systemml
Updated Branches:
  refs/heads/master c1db484d6 -> d01d13c4b


[SYSTEMML-1934] Fix codegen optimizer (incorrect structural pruning)

This patch fixes a severe codegen optimizer issue where, in special
cases, the positions of subproblems were set incorrectly, leading to
wrong mappings of optimal subproblem plans to the global plan. We
now use a much simpler and more robust creation of these mappings.

With this patch, we now (1) find the optimal plans in scenarios where we
previously missed them (e.g., Mlogreg), and (2) structural pruning is
more effective. For example, on GLM binomial probit, there
are 264,371 plans - with cost-based pruning this is reduced to 33,388,
and with additional structural pruning it is further reduced to 9,574 plans.

Furthermore, this patch improves the trace information of the
codegen optimizer and fixes an issue with statistics maintenance
when structural pruning is disabled.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d01d13c4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d01d13c4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d01d13c4

Branch: refs/heads/master
Commit: d01d13c4bc7b4da72ea399d76946cefda31fd4a4
Parents: c1db484
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Sep 24 22:01:27 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Sep 24 22:01:27 2017 -0700

----------------------------------------------------------------------
 .../opt/PlanSelectionFuseCostBasedV2.java       | 20 +++--
 .../hops/codegen/opt/ReachabilityGraph.java     | 92 +++++++++-----------
 2 files changed, 55 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d01d13c4/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 30631d0..7c27dcf 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -92,8 +92,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 	private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
 	
 	//optimizer configuration
-	public static boolean USE_COST_PRUNING = true;
-	public static boolean USE_STRUCTURAL_PRUNING = true;
+	public static boolean COST_PRUNING = true;
+	public static boolean STRUCTURAL_PRUNING = false;
 	
 	private static final IDSequence COST_ID = new IDSequence();
 	private static final TemplateRow ROW_TPL = new TemplateRow();
@@ -149,8 +149,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			//prepare pruning helpers and prune memo table w/ determined mat points
 			StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), 
 				getReadCost(part, memo), getWriteCost(part.getRoots(), memo));
-			ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
-			if( USE_STRUCTURAL_PRUNING ) {
+			ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
+			if( STRUCTURAL_PRUNING ) {
 				part.setMatPointsExt(rgraph.getSortedSearchSpace());
 				for( Long hopID : part.getPartition() )
 					memo.pruneRedundant(hopID, true, part.getMatPointsExt());
@@ -210,15 +210,19 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			long pskip = 0; //skip after costing
 			
 			//skip plans with structural pruning
-			if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) {
+			if( STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) {
 				//compute skip (which also acts as boundary for subproblems)
 				pskip = rgraph.getNumSkipPlans(plan);
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Enum: Structural pruning for cut set: "+rgraph.getCutSet(plan));
 				
 				//start increment rgraph get subproblems
 				SubProblem[] prob = rgraph.getSubproblems(plan);
 				
 				//solve subproblems independently and combine into best plan
 				for( int j=0; j<prob.length; j++ ) {
+					if( LOG.isTraceEnabled() )
+						LOG.trace("Enum: Subproblem "+(j+1)+"/"+prob.length+": "+prob[j]);
 					boolean[] bestTmp = enumPlans(memo, part, 
 						costs, null, prob[j].freeMat, prob[j].offset, bestC);
 					LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos);
@@ -228,7 +232,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 				//the default code path; hence we postpone the skip after costing
 			}
 			//skip plans with branch and bound pruning (cost)
-			else if( USE_COST_PRUNING ) {
+			else if( COST_PRUNING ) {
 				double lbC = Math.max(costs._read, costs._compute) + costs._write
 					+ getMaterializationCost(part, matPoints, memo, plan);
 				if( lbC >= bestC ) {
@@ -241,7 +245,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			}
 			
 			//cost assignment on hops. Stop early if exceeds bestC.
-			double pCBound = USE_COST_PRUNING ? bestC : Double.MAX_VALUE;
+			double pCBound = COST_PRUNING ? bestC : Double.MAX_VALUE;
 			double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts, pCBound);
 			if (LOG.isTraceEnabled())
 				LOG.trace("Enum: " + Arrays.toString(plan) + " -> " + C);
@@ -263,7 +267,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		}
 		
 		if( DMLScript.STATISTICS ) {
-			Statistics.incrementCodegenEnumAllP((rgraph!=null)?len:0);
+			Statistics.incrementCodegenEnumAllP((rgraph!=null||!STRUCTURAL_PRUNING)?len:0);
 			Statistics.incrementCodegenEnumEval(numEvalPlans);
 			Statistics.incrementCodegenEnumEvalP(numEvalPartPlans);
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/d01d13c4/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
index 0c829e8..fb7840b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
@@ -118,28 +118,26 @@ public class ReachabilityGraph
 		_cutSets = cutSets.stream()
 				.sorted(Comparator.comparing(p -> p.getRight()))
 				.map(p -> p.getLeft()).toArray(CutSet[]::new);
-	
+		
+		//created sorted order of materialization points
+		//(cut sets in predetermined order, all other points appended)
 		HashMap<InterestingPoint, Integer> probe = new HashMap<>();
 		ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
 		for( CutSet cs : _cutSets ) {
-			cs.updatePos(lsearchSpace.size());
-			cs.updatePartitions(probe);
 			CollectionUtils.addAll(lsearchSpace, cs.cut);
-			for( InterestingPoint p: cs.cut )
-				probe.put(p, probe.size()-1);
+			for( InterestingPoint p : cs.cut )
+				probe.put(p, probe.size());
 		}
 		for( InterestingPoint p : part.getMatPointsExt() )
 			if( !probe.containsKey(p) ) {
 				lsearchSpace.add(p);
-				probe.put(p, probe.size()-1);
+				probe.put(p, probe.size());
 			}
 		_searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);
 		
-		//materialize partition indices
-		for( CutSet cs : _cutSets ) {
-			cs.updatePartitionIndexes(probe);
-			cs.finalizePartition();
-		}
+		//finalize cut sets (update positions wrt search space)
+		for( CutSet cs : _cutSets )
+			cs.updatePositions(probe);
 		
 		//final sanity check of interesting points
 		if( _searchSpace.length != part.getMatPointsExt().length )
@@ -175,7 +173,7 @@ public class ReachabilityGraph
 	public long getNumSkipPlans(boolean[] plan) {
 		for( CutSet cs : _cutSets )
 			if( isCutSet(cs, plan) ) {
-				int pos = cs.posCut[cs.posCut.length-1];				
+				int pos = cs.posCut[cs.posCut.length-1];
 				return UtilFunctions.pow(2, plan.length-pos-1);
 			}
 		throw new RuntimeException("Failed to compute "
@@ -271,48 +269,44 @@ public class ReachabilityGraph
 			freePos = pos;
 			freeMat = mat;
 		}
+		
+		@Override
+		public String toString() {
+			return "SubProblem: "+Arrays.toString(freeMat)+"; "
+				+offset+"; "+Arrays.toString(freePos);
+		}
 	}
 	
-	public static class CutSet {
-		public InterestingPoint[] cut;
-		public InterestingPoint[] left;
-		public InterestingPoint[] right;
-		public int[] posCut;
-		public int[] posLeft;
-		public int[] posRight;
+	private static class CutSet {
+		private final InterestingPoint[] cut;
+		private final InterestingPoint[] left;
+		private final InterestingPoint[] right;
+		private int[] posCut;
+		private int[] posLeft;
+		private int[] posRight;
 		
-		public CutSet(InterestingPoint[] cutPoints, 
+		private CutSet(InterestingPoint[] cutPoints, 
 				InterestingPoint[] l, InterestingPoint[] r) {
 			cut = cutPoints;
-			left = l;
-			right = r;
+			left = (InterestingPoint[]) ArrayUtils.addAll(cut, l);
+			right = (InterestingPoint[]) ArrayUtils.addAll(cut, r);
 		}
 		
-		public void updatePos(int index) {
-			posCut = new int[cut.length];
-			for(int i=0; i<posCut.length; i++)
-				posCut[i] = index + i;
-		}
-		
-		public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) {
-			left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p))
-				.toArray(InterestingPoint[]::new);
-			right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p))
-				.toArray(InterestingPoint[]::new);
-		}
-		
-		public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) {
-			posLeft = new int[left.length];
-			for(int i=0; i<left.length; i++)
-				posLeft[i] = probe.get(left[i]);
-			posRight = new int[right.length];
-			for(int i=0; i<right.length; i++)
-				posRight[i] = probe.get(right[i]);
-		}
-		
-		public void finalizePartition() {
-			left = (InterestingPoint[]) ArrayUtils.addAll(cut, left);
-			right = (InterestingPoint[]) ArrayUtils.addAll(cut, right);
+		private void updatePositions(HashMap<InterestingPoint,Integer> probe) {
+			int lenCut = cut.length;
+			posCut = new int[lenCut];
+			for(int i=0; i<lenCut; i++)
+				posCut[i] = probe.get(cut[i]);
+			
+			int lenLeft = left.length - cut.length;
+			posLeft = new int[lenLeft];
+			for(int i=0; i<lenLeft; i++)
+				posLeft[i] = probe.get(left[lenCut+i]);
+			
+			int lenRight = right.length - cut.length;
+			posRight = new int[lenRight];
+			for(int i=0; i<lenRight; i++)
+				posRight[i] = probe.get(right[lenCut+i]);
 		}
 		
 		@Override
@@ -329,12 +323,12 @@ public class ReachabilityGraph
 		private long _ID;
 		private InterestingPoint _p;
 		
-		public NodeLink(InterestingPoint p) {
+		private NodeLink(InterestingPoint p) {
 			_ID = _seqID.getNextID();
 			_p = p;
 		} 
 		
-		public void addInput(NodeLink in) {
+		private void addInput(NodeLink in) {
 			_inputs.add(in);
 		}