You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/10/22 00:15:01 UTC

systemml git commit: [SYSTEMML-1968] Improved codegen optimizer (cost, mat points, pruning)

Repository: systemml
Updated Branches:
  refs/heads/master 6de8f051d -> 311e4aac9


[SYSTEMML-1968] Improved codegen optimizer (cost, mat points, pruning)

This patch improves the cost-based codegen optimizer to address wrong
fusion decisions for large-scale computations. In detail, this includes:

1) Cost model: The cost model now accounts for the broadcast cost of
side inputs in distributed Spark operations. Furthermore, this also
includes a fix for calculating the compute costs in case of a mix of
row and cell operations of different dimensions.

2) Interesting points: To enable the reasoning about side inputs, we now
also consider template switches from cell to row templates as
interesting points.

3) Pruning of row templates: The above changes also revealed hidden
issues in the pruning of unnecessary row templates (conversion to cell
templates), which mistakenly removed necessary row templates and
ultimately led to runtime errors.

On a large-scale scenario of L2SVM over a 200M x 100 dense input
(160GB), this patch improved the end-to-end runtime for 20 outer
iterations from 942s to 273s (w/o codegen: 644s).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/311e4aac
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/311e4aac
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/311e4aac

Branch: refs/heads/master
Commit: 311e4aac9833397908a083d0a48d5bd3ba086283
Parents: 6de8f05
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sat Oct 21 16:41:53 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Oct 21 17:15:38 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/opt/PlanAnalyzer.java    |   2 +-
 .../opt/PlanSelectionFuseCostBasedV2.java       | 131 ++++++++++---------
 .../hops/codegen/template/CPlanMemoTable.java   |  25 ++--
 .../runtime/codegen/LibSpoofPrimitives.java     |   6 +-
 4 files changed, 91 insertions(+), 73 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
index 9910814..7d522b3 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanAnalyzer.java
@@ -267,7 +267,7 @@ public class PlanAnalyzer
 			for( int i=0; i<3; i++ ) {
 				if( refs[i] < 0 ) continue;
 				List<TemplateType> tmp = memo.getDistinctTemplateTypes(hopID, i, true);
-				if( memo.containsNotIn(refs[i], tmp, true, true) )
+				if( memo.containsNotIn(refs[i], tmp, true) )
 					ret.add(new InterestingPoint(DecisionType.TEMPLATE_CHANGE, hopID, refs[i]));
 			}
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index d2ed3ac..10875e8 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -86,6 +86,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 	//to cover result allocation, write into main memory, and potential evictions
 	private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024;  //2GB/s
 	private static final double READ_BANDWIDTH = 32d*1024*1024*1024;  //32GB/s
+	private static final double READ_BANDWIDTH_BROADCAST = WRITE_BANDWIDTH/4;
 	private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core
 		* InfrastructureAnalyzer.getLocalParallelism();
 	
@@ -146,7 +147,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 				getComputeCosts(memo.getHopRefs().get(hopID), computeCosts);
 			
 			//prepare pruning helpers and prune memo table w/ determined mat points
-			StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), 
+			StaticCosts costs = new StaticCosts(computeCosts, sumComputeCost(computeCosts), 
 				getReadCost(part, memo), getWriteCost(part.getRoots(), memo));
 			ReachabilityGraph rgraph = STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
 			if( STRUCTURAL_PRUNING ) {
@@ -339,14 +340,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		return costs;
 	}
 	
-	private static double getComputeCost(HashMap<Long, Double> computeCosts, CPlanMemoTable memo) {
-		double costs = 0;
-		for( Entry<Long,Double> e : computeCosts.entrySet() ) {
-			Hop mainInput = memo.getHopRefs()
-				.get(e.getKey()).getInput().get(0);
-			costs += getSize(mainInput) * e.getValue() / COMPUTE_BANDWIDTH;
-		}
-		return costs;
+	private static double sumComputeCost(HashMap<Long, Double> computeCosts) {
+		return computeCosts.values().stream()
+			.mapToDouble(d -> d/COMPUTE_BANDWIDTH).sum();
 	}
 	
 	private static long getSize(Hop hop) {
@@ -567,33 +563,39 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		}
 	}
 	
-	private static boolean isRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
-		//consider all aggregations other than root operation
-		MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
-		boolean ret = true;
-		for(int i=0; i<3; i++)
-			if( me.isPlanRef(i) )
-				ret &= rIsRowTemplateWithoutAggOrVects(memo, 
-					current.getInput().get(i), visited);
-		return ret;
+	private static HashSet<Long> getRowAggOpsWithRowRef(CPlanMemoTable memo, PlanPartition part) {
+		HashSet<Long> refAggs = new HashSet<>();
+		for( Long hopID : part.getPartition() ) {
+			if( !memo.contains(hopID, TemplateType.ROW) ) continue;
+			MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
+			for(int i=0; i<3; i++)
+				if( me.isPlanRef(i) && memo.contains(me.input(i), TemplateType.ROW) 
+					&& isRowAggOp(memo.getHopRefs().get(me.input(i))))
+					refAggs.add(me.input(i));
+		}
+		return refAggs;
 	}
 	
-	private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+	private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited, boolean inclRoot) {
 		if( visited.contains(current.getHopID()) )
 			return true;
 		
-		boolean ret = true;
 		MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
-		for(int i=0; i<3; i++)
+		boolean ret = !inclRoot || !isRowAggOp(current);
+		for(int i=0; i<3 && ret; i++)
 			if( me!=null && me.isPlanRef(i) )
-				ret &= rIsRowTemplateWithoutAggOrVects(memo, current.getInput().get(i), visited);
-		ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp
-			|| HopRewriteUtils.isBinary(current, OpOp2.CBIND));
+				ret &= rIsRowTemplateWithoutAggOrVects(memo, 
+					current.getInput().get(i), visited, true);
 		
 		visited.add(current.getHopID());
 		return ret;
 	}
 	
+	private static boolean isRowAggOp(Hop hop){
+		return (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp
+			|| HopRewriteUtils.isBinary(hop, OpOp2.CBIND));
+	}
+	
 	private static void pruneInvalidAndSpecialCasePlans(CPlanMemoTable memo, PlanPartition part) 
 	{	
 		//prune invalid row entries w/ violated blocksize constraint
@@ -613,9 +615,8 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 						&& HopRewriteUtils.isTransposeOperation(in));
 				if( isSpark && !validNcol ) {
 					List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
-					memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist));
-					if( !memo.contains(hopID) )
-						memo.removeAllRefTo(hopID);
+					memo.remove(memo.getHopRefs().get(hopID), TemplateType.ROW);
+					memo.removeAllRefTo(hopID, TemplateType.ROW);
 					if( LOG.isTraceEnabled() ) {
 						LOG.trace("Removed row memo table entries w/ violated blocksize constraint ("+hopID+"): "
 							+ Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
@@ -625,10 +626,11 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		}
 		
 		//prune row aggregates with pure cellwise operations
+		HashSet<Long> refAggs = getRowAggOpsWithRowRef(memo, part);
 		for( Long hopID : part.getPartition() ) {
 			MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
 			if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
-				&& isRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+				&& rIsRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>(), refAggs.contains(hopID)) ) {
 				List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); 
 				memo.remove(memo.getHopRefs().get(hopID), new HashSet<>(blacklist));
 				if( LOG.isTraceEnabled() ) {
@@ -698,28 +700,25 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 		//i.e., plans that become invalid after the previous pruning step
 		long hopID = current.getHopID();
 		if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) {
-			for( MemoTableEntry me : memo.get(hopID) ) {
-				if( me.type==TemplateType.ROW ) {
-					//convert leaf node with pure vector inputs
-					if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) {
+			for( MemoTableEntry me : memo.get(hopID, TemplateType.ROW) ) {
+				//convert leaf node with pure vector inputs
+				if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) {
+					me.type = TemplateType.CELL;
+					if( LOG.isTraceEnabled() )
+						LOG.trace("Converted leaf memo table entry from row to cell: "+me);
+				}
+				
+				//convert inner node without row template input
+				if( me.hasPlanRef() && !ROW_TPL.open(current) ) {
+					boolean hasRowInput = false;
+					for( int i=0; i<3; i++ )
+						if( me.isPlanRef(i) )
+							hasRowInput |= memo.contains(me.input(i), TemplateType.ROW);
+					if( !hasRowInput ) {
 						me.type = TemplateType.CELL;
 						if( LOG.isTraceEnabled() )
-							LOG.trace("Converted leaf memo table entry from row to cell: "+me);
-					}
-					
-					//convert inner node without row template input
-					if( me.hasPlanRef() && !ROW_TPL.open(current) ) {
-						boolean hasRowInput = false;
-						for( int i=0; i<3; i++ )
-							if( me.isPlanRef(i) )
-								hasRowInput |= memo.contains(me.input(i), TemplateType.ROW);
-						if( !hasRowInput ) {
-							me.type = TemplateType.CELL;
-							if( LOG.isTraceEnabled() )
-								LOG.trace("Converted inner memo table entry from row to cell: "+me);	
-						}
+							LOG.trace("Converted inner memo table entry from row to cell: "+me);	
 					}
-					
 				}
 			}
 		}
@@ -834,14 +833,16 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 				String type = (best !=null) ? best.type.name() : "HOP";
 				LOG.trace("Cost vector ("+type+" "+currentHopId+"): "+costVect);
 			}
-			double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write
-				+ Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH,
-				costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
+			double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH
+				+ Math.max(costVect.getInputSize() * 8 / READ_BANDWIDTH,
+				costVect.computeCosts/ COMPUTE_BANDWIDTH);
+			//read correction for distributed computation
+			Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
+			if( driver.getMemEstimate() > OptimizerUtils.getLocalMemBudget() )
+				tmpCosts += costVect.getSideInputSize() * 8 / READ_BANDWIDTH_BROADCAST;
 			//sparsity correction for outer-product template (and sparse-safe cell)
-			if( best != null && best.type == TemplateType.OUTER ) {
-				Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
+			if( best != null && best.type == TemplateType.OUTER )
 				tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
-			}
 			costs += tmpCosts;
 		}
 		//add costs for non-partition read in the middle of fused operator
@@ -978,12 +979,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			costs = 1;
 		}
 		else if( current instanceof AggBinaryOp ) {
-			//outer product template
-			if( HopRewriteUtils.isOuterProductLikeMM(current) )
-				costs = 2 * current.getInput().get(0).getDim2();
-			//row template w/ matrix-vector or matrix-matrix
-			else
-				costs = 2 * current .getDim2();
+			//outer product template w/ matrix-matrix 
+			//or row template w/ matrix-vector or matrix-matrix
+			costs = 2 * current.getInput().get(0).getDim2();
 		}
 		else if( current instanceof AggUnaryOp) {
 			switch(((AggUnaryOp)current).getOp()) {
@@ -993,10 +991,15 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			case MAX:    costs = 1; break;
 			default:
 				LOG.warn("Cost model not "
-					+ "implemented yet for: "+((AggUnaryOp)current).getOp());			
+					+ "implemented yet for: "+((AggUnaryOp)current).getOp());
 			}
 		}
 		
+		//scale by current output size in order to correctly reflect
+		//a mix of row and cell operations in the same fused operator
+		//(e.g., row template with fused column vector operations)
+		costs *= getSize(current);
+		
 		computeCosts.put(current.getHopID(), costs);
 	}
 	
@@ -1025,8 +1028,14 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
 			//ensures that input sizes are not double counted
 			inSizes.put(hopID, inputSize);
 		}
-		public double getSumInputSizes() {
+		public double getInputSize() {
+			return inSizes.values().stream()
+				.mapToDouble(d -> d.doubleValue()).sum();
+		}
+		public double getSideInputSize() {
+			double max = getMaxInputSize();
 			return inSizes.values().stream()
+				.filter(d -> d < max)
 				.mapToDouble(d -> d.doubleValue()).sum();
 		}
 		public double getMaxInputSize() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 99ffc8d..5eedc7b 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -95,11 +95,10 @@ public class CPlanMemoTable
 			.anyMatch(p -> (!checkClose||!p.isClosed()) && probe.contains(p.type));
 	}
 	
-	public boolean containsNotIn(long hopID, Collection<TemplateType> types, 
-		boolean checkChildRefs, boolean excludeCell) {
+	public boolean containsNotIn(long hopID, 
+		Collection<TemplateType> types, boolean checkChildRefs) {
 		return contains(hopID) && get(hopID).stream()
-			.anyMatch(p -> (!checkChildRefs || p.hasPlanRef()) 
-				&& (!excludeCell || p.type!=TemplateType.CELL)
+			.anyMatch(p -> (!checkChildRefs || p.hasPlanRef())
 				&& p.isValid() && !types.contains(p.type));
 	}
 	
@@ -153,14 +152,22 @@ public class CPlanMemoTable
 			.removeIf(p -> blackList.contains(p));
 	}
 	
+	public void remove(Hop hop, TemplateType type) {
+		_plans.get(hop.getHopID())
+			.removeIf(p -> p.type == type);
+	}
+	
 	public void removeAllRefTo(long hopID) {
+		removeAllRefTo(hopID, null);
+	}
+	
+	public void removeAllRefTo(long hopID, TemplateType type) {
 		//recursive removal of references
 		for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() ) {
-			if( !e.getValue().isEmpty() ) {
-				e.getValue().removeIf(p -> p.hasPlanRefTo(hopID));
-				if( e.getValue().isEmpty() )
-					removeAllRefTo(e.getKey());
-			}
+			if( e.getValue().isEmpty() || e.getKey()==hopID ) 
+				continue;
+			e.getValue().removeIf(p -> p.hasPlanRefTo(hopID)
+				&& (type==null || p.type==type));
 		}
 	}
 	

http://git-wip-us.apache.org/repos/asf/systemml/blob/311e4aac/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index 7624d96..91fde5e 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -1788,11 +1788,13 @@ public class LibSpoofPrimitives
 	//dynamic memory management
 	
 	public static void setupThreadLocalMemory(int numVectors, int len) {
-		setupThreadLocalMemory(numVectors, len, -1);
+		if( numVectors > 0 )
+			setupThreadLocalMemory(numVectors, len, -1);
 	}
 	
 	public static void setupThreadLocalMemory(int numVectors, int len, int len2) {
-		memPool.set(new VectorBuffer(numVectors, len, len2));
+		if( numVectors > 0 )
+			memPool.set(new VectorBuffer(numVectors, len, len2));
 	}
 	
 	public static void cleanupThreadLocalMemory() {