You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/08/05 07:55:49 UTC

[2/3] systemml git commit: [SYSTEMML-1741, 1536, 1296] New cost-based codegen optimizer (V2)

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
new file mode 100644
index 0000000..2fa0de7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -0,0 +1,1100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
+import org.apache.sysml.hops.codegen.template.TemplateRow;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * This cost-based plan selection algorithm chooses fused operators
+ * based on the DAG structure and resulting overall costs. This includes
+ * holistic decisions on 
+ * <ul>
+ *   <li>Materialization points per consumer</li>
+ *   <li>Sparsity exploitation and operator ordering</li>
+ *   <li>Decisions on overlapping template types</li>
+ *   <li>Decisions on multi-aggregates with shared reads</li>
+ *   <li>Constraints (e.g., memory budgets and block sizes)</li>  
+ * </ul>
+ * 
+ */
+public class PlanSelectionFuseCostBasedV2 extends PlanSelection
+{	
+	//class logger (trace output of enumeration, pruning, and multi-agg decisions)
+	private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());
+	
+	//common bandwidth characteristics (in bytes/s or FLOP/s), with a conservative write 
+	//bandwidth in order to cover result allocation, write into main memory, and potential 
+	//evictions; cell counts are converted to bytes via size*8 in the cost functions
+	private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024;  //2GB/s
+	private static final double READ_BANDWIDTH = 32d*1024*1024*1024;  //32GB/s
+	//compute bandwidth scales with the local degree of parallelism
+	private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core
+		* InfrastructureAnalyzer.getLocalParallelism();
+	
+	//sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans
+	//(used as sparsity correction of outer-product plans if the driver dims are unknown)
+	private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
+	
+	//optimizer configuration: enable/disable the two pruning techniques of enumPlans
+	private static final boolean USE_COST_PRUNING = true;
+	private static final boolean USE_STRUCTURAL_PRUNING = true;
+	
+	//NOTE(review): not referenced in this part of the file; presumably generates the
+	//ids of cost vectors (CostVector.ID) -- verify
+	private static final IDSequence COST_ID = new IDSequence();
+	//row template instance used to check if a hop can open a row template (open(hop))
+	private static final TemplateRow ROW_TPL = new TemplateRow();
+	//comparator for plan selection without a preferred template type
+	private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator();
+	//comparator for plan selection with a preferred template type; it is mutable
+	//(setType), hence kept per selector instance rather than static
+	private final TypedPlanComparator _typedCompare = new TypedPlanComparator();
+	
+	/**
+	 * Selects fusion plans for the given DAG roots. The DAG is split into connected
+	 * plan partitions, each of which receives its within-partition multi-aggregates
+	 * and is then optimized independently. Afterwards, multi-aggregates across
+	 * partitions are added, and all distinct best plans are written back to the memo.
+	 * 
+	 * @param memo memoization table of partial fusion plans (modified in place)
+	 * @param roots DAG roots
+	 */
+	@Override
+	public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) 
+	{
+		//phase 1: connected partitions incl. nodes, roots, and mat points
+		Collection<PlanPartition> partitions = 
+			PlanAnalyzer.analyzePlanPartitions(memo, roots, true);
+		
+		//phase 2: per partition, composite templates first, then enumeration/selection
+		partitions.forEach(p -> {
+			createAndAddMultiAggPlans(memo, p.getPartition(), p.getRoots());
+			selectPlans(memo, p);
+		});
+		
+		//phase 3: composite templates spanning multiple partitions
+		createAndAddMultiAggPlans(memo, roots);
+		
+		//phase 4: keep only the distinct best plans per hop
+		getBestPlans().forEach((hopID, plans) -> memo.setDistinct(hopID, plans));
+	}
+	
+	/**
+	 * Optimizes a single plan partition. It first prunes row templates of roots
+	 * that contain no aggregation (if a cell alternative exists) and dominated
+	 * outer-product plans. If the partition has no interesting materialization
+	 * points, it selects plans via fuse-all. Otherwise, it enumerates and costs
+	 * all materialization assignments and prunes the memo table w.r.t. the best one.
+	 * 
+	 * @param memo memoization table of partial fusion plans (modified in place)
+	 * @param part connected plan partition to optimize
+	 */
+	private void selectPlans(CPlanMemoTable memo, PlanPartition part) 
+	{
+		//prune row aggregates with pure cellwise operations
+		for( Long hopID : part.getRoots() ) {
+			MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
+			//defensive null check: getBest may not yield a plan for this hop
+			//NOTE(review): confirm the null contract of CPlanMemoTable.getBest
+			if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
+				&& isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+				List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW); 
+				memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist));
+				if( LOG.isTraceEnabled() ) {
+					LOG.trace("Removed row memo table entries w/o aggregation: "
+						+ Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+				}
+			}
+		}
+		
+		//prune suboptimal outer product plans that are dominated by outer product plans w/ same number of 
+		//references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))), 
+		//we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern.
+		for( Long hopID : part.getPartition() ) {
+			if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) {
+				List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER);
+				MemoTableEntry me1 = entries.get(0);
+				MemoTableEntry me2 = entries.get(1);
+				MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
+				if( rmEntry != null ) {
+					memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
+					memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+					if( LOG.isTraceEnabled() )
+						LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
+				}
+			}
+		}
+		
+		//if no materialization points, use basic fuse-all w/ partition awareness
+		if( part.getMatPointsExt() == null || part.getMatPointsExt().length==0 ) {
+			for( Long hopID : part.getRoots() )
+				rSelectPlansFuseAll(memo, 
+					memo.getHopRefs().get(hopID), null, part.getPartition());
+		}
+		else {
+			//obtain hop compute costs per cell once
+			HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
+			for( Long hopID : part.getRoots() )
+				rGetComputeCosts(memo.getHopRefs().get(hopID), part.getPartition(), computeCosts);
+			
+			//prepare pruning helpers and prune memo table w/ determined mat points
+			StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo), 
+				getReadCost(part, memo), getWriteCost(part.getRoots(), memo));
+			ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
+			if( USE_STRUCTURAL_PRUNING ) {
+				//reorder the search space (mat points) as given by the reachability graph
+				part.setMatPointsExt(rgraph.getSortedSearchSpace());
+				for( Long hopID : part.getPartition() )
+					memo.pruneRedundant(hopID, true, part.getMatPointsExt());
+			}
+			
+			//enumerate and cost plans, returns optimal plan
+			boolean[] bestPlan = enumPlans(memo, part, costs, rgraph, 
+					part.getMatPointsExt(), 0, Double.MAX_VALUE);
+			
+			//prune memo table wrt best plan and select plans
+			//(separate visit sets, as both traversals memoize independently)
+			HashSet<Long> visited = new HashSet<Long>();
+			for( Long hopID : part.getRoots() )
+				rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID), 
+					visited, part, part.getMatPointsExt(), bestPlan);
+			HashSet<Long> visited2 = new HashSet<Long>();
+			for( Long hopID : part.getRoots() )
+				rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID), 
+					visited2, part, bestPlan);
+			
+			for( Long hopID : part.getRoots() )
+				rSelectPlansFuseAll(memo, 
+					memo.getHopRefs().get(hopID), null, part.getPartition());
+		}
+	}
+	
+	/**
+	 * Core plan enumeration algorithm, invoked recursively for conditionally independent
+	 * subproblems. This algorithm fully explores the exponential search space of 2^m,
+	 * where m is the number of interesting materialization points. We iterate over
+	 * a linearized search space without ever instantiating the search tree. Furthermore,
+	 * in order to reduce the enumeration overhead, we apply two high-impact pruning
+	 * techniques (1) pruning by evolving lower/upper cost bounds, and (2) pruning by
+	 * conditional structural properties (so-called cutsets of interesting points). 
+	 * 
+	 * Cost-based pruning is only applied once at least one plan has been fully costed
+	 * in this invocation. Otherwise, an upper bound that is tighter than all plans of a
+	 * (sub)problem would prune the entire search space and leave no plan to return.
+	 * 
+	 * @param memo memoization table of partial fusion plans
+	 * @param part connected component (partition) of partial fusion plans with all necessary meta data
+	 * @param costs summary of static costs (e.g., partition reads, writes, and compute costs per operator)
+	 * @param rgraph reachability graph of interesting materialization points (null disables structural pruning)
+	 * @param matPoints sorted materialization points (defines the search space)
+	 * @param off offset for recursive invocation, indicating the fixed plan part
+	 * @param bestC initial upper bound of plan costs (e.g., Double.MAX_VALUE)
+	 * @return optimal assignment of the free materialization points (w/o the fixed offset part)
+	 */
+	private static boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs, 
+		ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off, double bestC) 
+	{
+		//scan linearized search space, w/ skips for branch and bound pruning
+		//and structural pruning (where we solve conditionally independent problems)
+		//bestC is monotonically non-increasing and serves as the upper bound
+		long len = (long)Math.pow(2, matPoints.length-off);
+		boolean[] bestPlan = null;
+		int numEvalPlans = 0;
+		
+		for( long i=0; i<len; i++ ) {
+			//construct assignment
+			boolean[] plan = createAssignment(matPoints.length-off, off, i);
+			long pskip = 0; //skip after costing
+			
+			//skip plans with structural pruning
+			if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) {
+				//compute skip (which also acts as boundary for subproblems)
+				pskip = rgraph.getNumSkipPlans(plan);
+				
+				//start increment rgraph get subproblems
+				SubProblem[] prob = rgraph.getSubproblems(plan);
+				
+				//solve subproblems independently and combine into best plan; subproblems are
+				//costed w/ their own set of mat points, hence they start w/o an upper bound
+				for( int j=0; j<prob.length; j++ ) {
+					boolean[] bestTmp = enumPlans(memo, part, 
+						costs, null, prob[j].freeMat, prob[j].offset, Double.MAX_VALUE);
+					LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos);
+				}
+				
+				//note: the overall plan costs are evaluated in full, which reuses
+				//the default code path; hence we postpone the skip after costing
+			}
+			//skip plans with branch and bound pruning (cost), but only once a plan has
+			//been costed (otherwise all plans could be pruned, leaving bestPlan null)
+			else if( USE_COST_PRUNING && bestPlan != null ) {
+				double lbC = Math.max(costs._read, costs._compute) + costs._write 
+					+ getMaterializationCost(part, matPoints, memo, plan);
+				if( lbC >= bestC ) {
+					long skip = getNumSkipPlans(plan);
+					if( LOG.isTraceEnabled() )
+						LOG.trace("Enum: Skip "+skip+" plans (by cost).");
+					i += skip - 1;
+					continue;
+				}
+			}
+			
+			//cost assignment on hops
+			double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts);
+			numEvalPlans ++;
+			if( LOG.isTraceEnabled() )
+				LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C);
+			
+			//cost comparisons (the first costed plan always becomes the incumbent)
+			if( bestPlan == null || C < bestC ) {
+				bestC = C;
+				bestPlan = plan;
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Enum: Found new best plan.");
+			}
+			
+			//post skipping
+			i += pskip;
+			if( pskip !=0 && LOG.isTraceEnabled() )
+				LOG.trace("Enum: Skip "+pskip+" plans (by structure).");
+		}
+		
+		if( DMLScript.STATISTICS )
+			Statistics.incrementCodegenFPlanCompile(numEvalPlans);
+		if( LOG.isTraceEnabled() )
+			LOG.trace("Enum: Optimal plan: "+Arrays.toString(bestPlan));
+		
+		//copy best plan w/o fixed offset plan
+		return Arrays.copyOfRange(bestPlan, off, bestPlan.length);
+	}
+	
+	/**
+	 * Creates the materialization assignment for a position of the linearized search space.
+	 * The first {@code off} entries (fixed part) are set to true, and the remaining
+	 * {@code len} entries hold the binary representation of {@code pos}, most
+	 * significant bit first. Uses exact integer bit operations; the previous
+	 * floating-point powers/modulo lost precision for positions beyond 2^53.
+	 * 
+	 * @param len number of free materialization points (at most 63)
+	 * @param off number of fixed leading materialization points
+	 * @param pos position in the linearized search space, in [0, 2^len)
+	 * @return assignment of length off+len
+	 */
+	private static boolean[] createAssignment(int len, int off, long pos) {
+		boolean[] ret = new boolean[off+len];
+		Arrays.fill(ret, 0, off, true);
+		for( int i=0; i<len; i++ )
+			ret[off+i] = ((pos >>> (len-i-1)) & 1L) == 1L;
+		return ret;
+	}
+	
+	/**
+	 * Returns the number of assignments that share the prefix of the given plan up to
+	 * and including its last true entry, i.e., 2^k where k is the number of entries
+	 * after the last true entry (the whole plan length if there is no true entry).
+	 */
+	private static long getNumSkipPlans(boolean[] plan) {
+		int lastMat = ArrayUtils.lastIndexOf(plan, true);
+		int numFreeTail = plan.length - lastMat - 1;
+		return (long) Math.pow(2, numFreeTail);
+	}
+	
+	/**
+	 * Computes the materialization costs of a plan assignment. Each materialized target
+	 * hop is charged one write and one read of its output. Targets consumed outside the
+	 * partition are charged a write unless already accounted for. Every target is charged
+	 * at most once.
+	 */
+	private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) {
+		HashSet<Long> charged = new HashSet<>(); //targets already accounted for
+		double total = 0;
+		
+		//currently active materialization points: write and read back
+		for( int pos=0; pos<plan.length; pos++ ) {
+			long target = M[pos].getToHopID();
+			if( plan[pos] && charged.add(target) ) {
+				long cells = getSize(memo.getHopRefs().get(target));
+				total += cells * 8 / WRITE_BANDWIDTH + 
+						cells * 8 / READ_BANDWIDTH;
+			}
+		}
+		
+		//points with non-partition consumers: write only
+		for( Long target : part.getExtConsumed() ) {
+			if( charged.add(target) )
+				total += getSize(memo.getHopRefs().get(target)) * 8 / WRITE_BANDWIDTH;
+		}
+		return total;
+	}
+	
+	/**
+	 * Returns the costs of reading every partition input once.
+	 */
+	private static double getReadCost(PlanPartition part, CPlanMemoTable memo) {
+		double readCosts = 0;
+		for( Long inputID : part.getInputs() )
+			readCosts += getSize(memo.getHopRefs().get(inputID)) * 8 / READ_BANDWIDTH;
+		return readCosts;
+	}
+	
+	/**
+	 * Returns the costs of writing the outputs of the given hops (e.g., partition roots).
+	 */
+	private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) {
+		double writeCosts = 0;
+		Iterator<Long> it = R.iterator();
+		while( it.hasNext() ) {
+			Hop out = memo.getHopRefs().get(it.next());
+			writeCosts += getSize(out) * 8 / WRITE_BANDWIDTH;
+		}
+		return writeCosts;
+	}
+	
+	/**
+	 * Returns the total compute costs, where the per-cell compute costs of each hop
+	 * are scaled by the number of cells of its first (main) input.
+	 */
+	private static double getComputeCost(HashMap<Long, Double> computeCosts, CPlanMemoTable memo) {
+		double total = 0;
+		for( Long hopID : computeCosts.keySet() ) {
+			Hop hop = memo.getHopRefs().get(hopID);
+			long mainInputCells = getSize(hop.getInput().get(0));
+			total += mainInputCells * computeCosts.get(hopID) / COMPUTE_BANDWIDTH;
+		}
+		return total;
+	}
+	
+	/**
+	 * Returns the number of cells of the given hop, where dimensions smaller than 1
+	 * (presumably unknown sizes) are treated as 1.
+	 */
+	private static long getSize(Hop hop) {
+		long rows = Math.max(hop.getDim1(), 1);
+		long cols = Math.max(hop.getDim2(), 1);
+		return rows * cols;
+	}
+	
+	/**
+	 * Creates within-partition multi-aggregate (MAGG) plans. Full aggregations among the
+	 * partition roots are grouped in order into groups of at most three, and every valid
+	 * group of at least two aggregates is added to the memo table for each of its members.
+	 * 
+	 * @param memo memoization table of partial fusion plans (modified in place)
+	 * @param partition hop ids of the partition (not used by this method)
+	 * @param R hop ids of the partition roots
+	 */
+	private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R)
+	{
+		//index of hops consumed by any hop with plans; such hops are not used as
+		//multi-agg roots in order to avoid circular dependencies
+		HashSet<Long> refHops = new HashSet<Long>();
+		memo.getPlans().forEach((id, plans) -> {
+			if( !plans.isEmpty() )
+				memo.getHopRefs().get(id).getInput()
+					.forEach(c -> refHops.add(c.getHopID()));
+		});
+		
+		//full aggregations among the roots (the fact that they are in the same partition
+		//guarantees common subexpressions; full aggregations are by def root nodes)
+		ArrayList<Long> fullAggs = new ArrayList<Long>();
+		for( Long rootID : R ) {
+			if( !refHops.contains(rootID) && isMultiAggregateRoot(memo.getHopRefs().get(rootID)) )
+				fullAggs.add(rootID);
+		}
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Found within-partition ua(RC) aggregations: " +
+				Arrays.toString(fullAggs.toArray(new Long[0])));
+		}
+		
+		//groups of (at most) three aggregations; a trailing single aggregate is skipped
+		for( int lo=0; lo<fullAggs.size(); lo+=3 ) {
+			int hi = Math.min(lo+3, fullAggs.size());
+			int cnt = hi - lo;
+			if( cnt < 2 )
+				continue;
+			MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
+				fullAggs.get(lo), fullAggs.get(lo+1), (cnt==3)?fullAggs.get(lo+2):-1, cnt);
+			if( !isValidMultiAggregate(memo, me) ) {
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Removed invalid multiagg plan: "+me);
+				continue;
+			}
+			for( Long aggID : fullAggs.subList(lo, hi) ) {
+				memo.add(memo.getHopRefs().get(aggID), me);
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Added multiagg plan: "+aggID+" "+me);
+			}
+		}
+	}
+	
+	/**
+	 * Creates across-partition multi-aggregate (MAGG) plans for full aggregations with
+	 * shared reads. Candidates are all multi-aggregate roots reachable from the DAG roots
+	 * that do not yet have a MAGG plan. For each candidate, the dependent input aggregates
+	 * and the inputs of its fused operator are extracted. Candidates are then greedily
+	 * merged (as decided by AggregateInfo.isMergable/merge). Every resulting group of at
+	 * least two aggregates yields one MAGG entry, which is added to the memo table and
+	 * registered as best plan of each of its aggregates.
+	 * 
+	 * @param memo memoization table of partial fusion plans (modified in place)
+	 * @param roots DAG roots
+	 */
+	private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
+	{
+		//collect full aggregations as initial set of candidates
+		//(hop visit status is reset before and after the traversal)
+		HashSet<Long> fullAggs = new HashSet<Long>();
+		Hop.resetVisitStatus(roots);
+		for( Hop hop : roots )
+			rCollectFullAggregates(hop, fullAggs);
+		Hop.resetVisitStatus(roots);
+
+		//remove operators with assigned multi-agg plans
+		fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));
+	
+		//check applicability for further analysis (at least two candidates)
+		if( fullAggs.size() <= 1 )
+			return;
+	
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Found across-partition ua(RC) aggregations: " +
+				Arrays.toString(fullAggs.toArray(new Long[0])));
+		}
+		
+		//collect information for all candidates 
+		//(subsumed aggregations, and inputs to fused operators) 
+		List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>();
+		for( Long hopID : fullAggs ) {
+			Hop aggHop = memo.getHopRefs().get(hopID);
+			AggregateInfo tmp = new AggregateInfo(aggHop);
+			for( int i=0; i<aggHop.getInput().size(); i++ ) {
+				//for matrix multiply roots, the first input is a transpose (see
+				//isMultiAggregateRoot), hence start at the transpose's input
+				Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? 
+					aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
+				rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
+			}
+			//fallback if no fused inputs were found: use the (un-transposed) main inputs
+			if( tmp._fusedInputs.isEmpty() ) {
+				if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
+					tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
+					tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
+				}
+				else	
+					tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+			}
+			aggInfos.add(tmp);	
+		}
+		
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
+			for( AggregateInfo info : aggInfos )
+				LOG.trace(info);
+		}
+		
+		//sort aggregations by num dependencies to simplify merging
+		//clusters of aggregations with parallel dependencies
+		aggInfos = aggInfos.stream()
+			.sorted(Comparator.comparing(a -> a._inputAggs.size()))
+			.collect(Collectors.toList());
+		
+		//greedy grouping of multi-agg candidates: each pass merges every later
+		//mergeable info into the current one (removing it from the list); passes
+		//are repeated until a pass performs no merge
+		boolean converged = false;
+		while( !converged ) {
+			AggregateInfo merged = null;
+			for( int i=0; i<aggInfos.size(); i++ ) {
+				AggregateInfo current = aggInfos.get(i);
+				for( int j=i+1; j<aggInfos.size(); j++ ) {
+					AggregateInfo that = aggInfos.get(j);
+					if( current.isMergable(that) ) {
+						merged = current.merge(that);
+						aggInfos.remove(j); j--;
+					}
+				}
+			}
+			converged = (merged == null);
+		}
+		
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Merged across-partition ua(RC) aggregation info: ");
+			for( AggregateInfo info : aggInfos )
+				LOG.trace(info);
+		}
+		
+		//construct and add multiagg template plans (w/ max 3 aggregations)
+		//NOTE(review): aggs.length is not capped here; presumably AggregateInfo.isMergable
+		//bounds groups to three aggregates -- verify, as only aggs[0..2] are referenced
+		for( AggregateInfo info : aggInfos ) {
+			if( info._aggregates.size()<=1 )
+				continue;
+			Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
+			MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
+				aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length);
+			for( int i=0; i<aggs.length; i++ ) {
+				//add to memo table and register directly as best plan of this aggregate
+				memo.add(memo.getHopRefs().get(aggs[i]), me);
+				addBestPlan(aggs[i], me);
+				if( LOG.isTraceEnabled() )
+					LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me);
+				
+			}
+		}
+	}
+	
+	/**
+	 * Indicates if the given hop qualifies as multi-aggregate root. That is either a
+	 * full (RowCol) sum, sum_sq, min, or max aggregation, or an AggBinaryOp with a 1x1
+	 * output whose first input is a transpose operation.
+	 */
+	private static boolean isMultiAggregateRoot(Hop root) {
+		boolean fullUnaryAgg = HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) 
+			&& ((AggUnaryOp)root).getDirection()==Direction.RowCol;
+		if( fullUnaryAgg )
+			return true;
+		return root instanceof AggBinaryOp 
+			&& root.getDim1()==1 && root.getDim2()==1
+			&& HopRewriteUtils.isTransposeOperation(root.getInput().get(0));
+	}
+	
+	/**
+	 * Checks if a multi-aggregate plan is valid. The main inputs of all referenced
+	 * aggregates must have the same size (otherwise potential for incorrect results),
+	 * and the aggregates must be independent. That is, no aggregate may (transitively)
+	 * consume another input of the plan.
+	 */
+	private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
+		//consistent main input sizes w.r.t. the first aggregate
+		Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0);
+		for( int i=1; i<3; i++ ) {
+			if( me.isPlanRef(i) && !HopRewriteUtils.isEqualSize(refSize, 
+				memo.getHopRefs().get(me.input(i)).getInput().get(0)) )
+				return false;
+		}
+		
+		//independence: the sub-DAG of each aggregate must not contain the other inputs
+		for( int i=0; i<3; i++ ) {
+			if( !me.isPlanRef(i) )
+				continue;
+			HashSet<Long> others = new HashSet<Long>();
+			for( int j=0; j<3; j++ )
+				if( j != i )
+					others.add(me.input(j));
+			if( !rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), others) )
+				return false;
+		}
+		return true;
+	}
+	
+	/**
+	 * Checks that no hop of the sub-DAG rooted at {@code current} (including
+	 * {@code current} itself) has a hop id contained in {@code probe}.
+	 * 
+	 * @param current root of the sub-DAG to check
+	 * @param probe hop ids that must not be reachable
+	 * @return true if none of the probe hop ids is reachable
+	 */
+	private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
+		return rCheckMultiAggregate(current, probe, new HashSet<Long>());
+	}
+	
+	/**
+	 * Recursive helper that memoizes already checked hops. Without memoization, hops
+	 * reachable over multiple paths (common subexpressions) are re-checked once per
+	 * path, which is exponential in the worst case.
+	 */
+	private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe, HashSet<Long> visited) {
+		//a revisited hop was already checked successfully (a failed check aborts the traversal)
+		if( !visited.add(current.getHopID()) )
+			return true;
+		if( probe.contains(current.getHopID()) )
+			return false;
+		for( Hop c : current.getInput() )
+			if( !rCheckMultiAggregate(c, probe, visited) )
+				return false;
+		return true;
+	}
+	
+	/**
+	 * Collects the hop ids of all multi-aggregate roots in the DAG below (and including)
+	 * {@code current}. Each hop is processed at most once, using the hop visit status.
+	 */
+	private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
+		if( !current.isVisited() ) {
+			if( isMultiAggregateRoot(current) )
+				aggs.add(current.getHopID());
+			current.getInput().forEach(c -> rCollectFullAggregates(c, aggs));
+			current.setVisited();
+		}
+	}
+	
+	/**
+	 * Recursively extracts aggregate information for a multi-aggregate candidate. It
+	 * collects the dependent (input) aggregates anywhere below {@code current}. It also
+	 * collects the matrix inputs of the fused operator, i.e., inputs that are not
+	 * covered by plan references of the best memo table entries along the fused path.
+	 * 
+	 * @param type template type while inside the fused operator, or null once the
+	 *   traversal has left it
+	 */
+	private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) {
+		//dependent aggregates are collected regardless of the fused path
+		if( isMultiAggregateRoot(current) )
+			aggInfo.addInputAggregate(current.getHopID());
+		
+		//best plan of the current hop, only relevant inside the fused operator
+		boolean inFusedOp = (type != null);
+		MemoTableEntry best = inFusedOp ? memo.getBest(current.getHopID()) : null;
+		
+		List<Hop> inputs = current.getInput();
+		for( int pos=0; pos<inputs.size(); pos++ ) {
+			Hop in = inputs.get(pos);
+			if( best != null && best.isPlanRef(pos) ) {
+				//continue along the fused path
+				rExtractAggregateInfo(memo, in, aggInfo, type);
+				continue;
+			}
+			//matrix inputs at the boundary of the fused operator
+			if( inFusedOp && in.getDataType().isMatrix() )
+				aggInfo.addFusedInput(in.getHopID());
+			rExtractAggregateInfo(memo, in, aggInfo, null);
+		}
+	}
+	
+	/**
+	 * Indicates if the best row plan of {@code current} contains no aggregation
+	 * below the root. The root itself is excluded, and only its fused (plan-referenced)
+	 * inputs are checked recursively.
+	 */
+	private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+		MemoTableEntry rootPlan = memo.getBest(current.getHopID(), TemplateType.ROW);
+		boolean noAgg = true;
+		int i = 0;
+		while( i < 3 ) {
+			if( rootPlan.isPlanRef(i) ) {
+				//evaluate every fused input (no short-circuit)
+				boolean childOk = rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
+				noAgg = noAgg && childOk;
+			}
+			i++;
+		}
+		return noAgg;
+	}
+	
+	/**
+	 * Recursively checks that neither {@code current} nor any hop reachable via plan
+	 * references of the best row plans is an aggregation (AggUnaryOp or AggBinaryOp).
+	 * Already visited hops return true; their result is already part of the conjunction.
+	 */
+	private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+		if( visited.contains(current.getHopID()) )
+			return true;
+		
+		MemoTableEntry plan = memo.getBest(current.getHopID(), TemplateType.ROW);
+		boolean noAgg = true;
+		for( int i=0; i<3; i++ ) {
+			if( !plan.isPlanRef(i) )
+				continue;
+			//evaluate every fused input (no short-circuit), which also marks them visited
+			boolean childOk = rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
+			noAgg = childOk && noAgg;
+		}
+		boolean isAgg = current instanceof AggUnaryOp || current instanceof AggBinaryOp;
+		
+		visited.add(current.getHopID());
+		return noAgg && !isAgg;
+	}
+	
+	private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, 
+		PlanPartition part, InterestingPoint[] matPoints, boolean[] plan) 
+	{
+		//memoization (not via hops because in middle of dag)
+		if( visited.contains(current.getHopID()) )
+			return;
+		
+		//remove memo table entries if necessary
+		long hopID = current.getHopID();
+		if( part.getPartition().contains(hopID) && memo.contains(hopID) ) {
+			Iterator<MemoTableEntry> iter = memo.get(hopID).iterator();
+			while( iter.hasNext() ) {
+				MemoTableEntry me = iter.next();
+				if( !hasNoRefToMatPoint(hopID, me, matPoints, plan) && me.type!=TemplateType.OUTER ) {
+					iter.remove();
+					if( LOG.isTraceEnabled() )
+						LOG.trace("Removed memo table entry: "+me);
+				}
+			}
+		}
+		
+		//process children recursively
+		for( Hop c : current.getInput() )
+			rPruneSuboptimalPlans(memo, c, visited, part, matPoints, plan);
+		
+		visited.add(current.getHopID());
+	}
+	
+	/**
+	 * Bottom-up conversion of row template entries that became invalid after pruning
+	 * (see TemplateRow.open) into cell template entries. These are (1) leaf entries
+	 * without plan references whose hop has no matrix input, and (2) inner entries of
+	 * hops that cannot open a row template and have no referenced input with a row plan.
+	 */
+	private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
+		long hopID = current.getHopID();
+		
+		//memoization (not via hops because in middle of dag)
+		if( visited.contains(hopID) )
+			return;
+		
+		//inputs first (bottom-up)
+		current.getInput().forEach(c -> rPruneInvalidPlans(memo, c, visited, part, plan));
+		
+		if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) {
+			for( MemoTableEntry me : memo.get(hopID) ) {
+				if( me.type != TemplateType.ROW )
+					continue;
+				if( !me.hasPlanRef() ) {
+					//leaf node with pure vector inputs
+					if( !TemplateUtils.hasMatrixInput(current) ) {
+						me.type = TemplateType.CELL;
+						if( LOG.isTraceEnabled() )
+							LOG.trace("Converted leaf memo table entry from row to cell: "+me);
+					}
+				}
+				else if( !ROW_TPL.open(current) ) {
+					//inner node: requires at least one referenced input w/ row plan
+					boolean hasRowInput = false;
+					for( int i=0; i<3; i++ )
+						hasRowInput |= me.isPlanRef(i) && memo.contains(me.input(i), TemplateType.ROW);
+					if( !hasRowInput ) {
+						me.type = TemplateType.CELL;
+						if( LOG.isTraceEnabled() )
+							LOG.trace("Converted inner memo table entry from row to cell: "+me);	
+					}
+				}
+			}
+		}
+		
+		visited.add(hopID);
+	}
+	
+	private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition) 
+	{	
+		if( isVisited(current.getHopID(), currentType) 
+			|| !partition.contains(current.getHopID()) )
+			return;
+		
+		//step 1: prune subsumed plans of same type
+		if( memo.contains(current.getHopID()) ) {
+			HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+			List<MemoTableEntry> hopP = memo.get(current.getHopID());
+			for( MemoTableEntry e1 : hopP )
+				for( MemoTableEntry e2 : hopP )
+					if( e1 != e2 && e1.subsumes(e2) )
+						rmSet.add(e2);
+			memo.remove(current, rmSet);
+		}
+		
+		//step 2: select plan for current path
+		MemoTableEntry best = null;
+		if( memo.contains(current.getHopID()) ) {
+			if( currentType == null ) {
+				best = memo.get(current.getHopID()).stream()
+					.filter(p -> isValid(p, current))
+					.min(BASE_COMPARE).orElse(null);
+			}
+			else {
+				_typedCompare.setType(currentType);
+				best = memo.get(current.getHopID()).stream()
+					.filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+					.min(_typedCompare).orElse(null);
+			}
+			addBestPlan(current.getHopID(), best);
+		}
+		
+		//step 3: recursively process children
+		for( int i=0; i< current.getInput().size(); i++ ) {
+			TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
+			rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
+		}
+		
+		setVisited(current.getHopID(), currentType);
+	}
+	
+	/////////////////////////////////////////////////////////
+	// Cost model fused operators w/ materialization points
+	//////////
+	
+	/**
+	 * Computes the costs of the given plan assignment as the sum of plan costs over all
+	 * partition roots.
+	 */
+	private static double getPlanCost(CPlanMemoTable memo, PlanPartition part, 
+		InterestingPoint[] matPoints,boolean[] plan, HashMap<Long, Double> computeCosts) 
+	{
+		//high level heuristic: every hop or fused operator costs WRITE + max(COMPUTE, READ),
+		//where WRITE is given by the output size, READ by the input sizes, and COMPUTE by
+		//operation-specific FLOP counts times number of cells of the main input
+		//(disregarding sparsity for now).
+		
+		//visit marks are shared across roots to avoid double counting common parts
+		HashSet<VisitMarkCost> visited = new HashSet<>();
+		double total = 0;
+		for( Long rootID : part.getRoots() ) {
+			Hop root = memo.getHopRefs().get(rootID);
+			total += rGetPlanCosts(memo, root, visited, part, matPoints, plan, computeCosts, null, null);
+		}
+		return total;
+	}
+	
+	private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<VisitMarkCost> visited, 
+		PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts, 
+		CostVector costsCurrent, TemplateType currentType) 
+	{
+		//memoization per hop id and cost vector to account for redundant
+		//computation without double counting materialized results or compute
+		//costs of complex operation DAGs within a single fused operator
+		VisitMarkCost tag = new VisitMarkCost(current.getHopID(), 
+			(costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID);
+		if( visited.contains(tag) )
+			return 0;
+		visited.add(tag);
+		
+		//open template if necessary, including memoization
+		//under awareness of current plan choice
+		MemoTableEntry best = null;
+		boolean opened = false;
+		if( memo.contains(current.getHopID()) ) {
+			//note: this is the inner loop of plan enumeration and hence, we do not 
+			//use streams, lambda expressions, etc to avoid unnecessary overhead
+			long hopID = current.getHopID();
+			if( currentType == null ) {
+				for( MemoTableEntry me : memo.get(hopID) )
+					best = isValid(me, current) 
+						&& hasNoRefToMatPoint(hopID, me, matPoints, plan) 
+						&& BasicPlanComparator.icompare(me, best)<0 ? me : best;
+				opened = true;
+			}
+			else {
+				for( MemoTableEntry me : memo.get(hopID) )
+					best = (me.type == currentType || me.type==TemplateType.CELL)
+						&& hasNoRefToMatPoint(hopID, me, matPoints, plan) 
+						&& TypedPlanComparator.icompare(me, best, currentType)<0 ? me : best;
+			}
+		}
+		
+		//create new cost vector if opened, initialized with write costs
+		CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
+		double costs = 0;
+		
+		//add other roots for multi-agg template to account for shared costs
+		if( opened && best != null && best.type == TemplateType.MAGG ) {
+			//account costs to first multi-agg root 
+			if( best.input1 == current.getHopID() )
+				for( int i=1; i<3; i++ ) {
+					if( !best.isPlanRef(i) ) continue;
+					costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited, 
+						part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG);
+				}
+			//skip other multi-agg roots
+			else
+				return 0;
+		}
+		
+		//add compute costs of current operator to costs vector
+		if( part.getPartition().contains(current.getHopID()) )
+			costVect.computeCosts += computeCosts.get(current.getHopID());
+		
+		//process children recursively
+		for( int i=0; i< current.getInput().size(); i++ ) {
+			Hop c = current.getInput().get(i);
+			if( best!=null && best.isPlanRef(i) )
+				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type);
+			else if( best!=null && isImplicitlyFused(current, i, best.type) )
+				costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
+			else { //include children and I/O costs
+				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null);
+				if( costVect != null && c.getDataType().isMatrix() )
+					costVect.addInputSize(c.getHopID(), getSize(c));
+			}
+		}
+		
+		//add costs for opened fused operator
+		if( part.getPartition().contains(current.getHopID()) ) {
+			if( opened ) {
+				if( LOG.isTraceEnabled() ) {
+					String type = (best !=null) ? best.type.name() : "HOP";
+					LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect);
+				}
+				double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write
+					+ Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH,
+					costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
+				//sparsity correction for outer-product template (and sparse-safe cell)
+				if( best != null && best.type == TemplateType.OUTER ) {
+					Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
+					tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
+				}
+				costs += tmpCosts;
+			}
+			//add costs for non-partition read in the middle of fused operator
+			else if( part.getExtConsumed().contains(current.getHopID()) ) {
+				costs += rGetPlanCosts(memo, current, visited,
+					part, matPoints, plan, computeCosts, null, null);
+			}
+		}
+		
+		//sanity check non-negative costs
+		if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
+			throw new RuntimeException("Wrong cost estimate: "+costs);
+		
+		return costs;
+	}
+	
+	/**
+	 * Recursively computes the compute costs of all hops of the given
+	 * partition (children before parents) and stores them by hop ID.
+	 * The costs are relative per-operation weights, which rGetPlanCosts
+	 * later scales by the maximum input size of a fused operator.
+	 * Hops outside the partition or with already known costs are skipped,
+	 * and operators without a dedicated cost model default to 1.
+	 * 
+	 * @param current current hop
+	 * @param partition hop IDs of the plan partition
+	 * @param computeCosts map of hop ID to compute costs (updated in place)
+	 */
+	private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts) 
+	{
+		if( computeCosts.containsKey(current.getHopID()) 
+			|| !partition.contains(current.getHopID()) )
+			return;
+		
+		//recursively process children
+		for( Hop c : current.getInput() )
+			rGetComputeCosts(c, partition, computeCosts);
+		
+		//get costs for given hop (default 1 for unknown operators)
+		double costs = 1;
+		if( current instanceof UnaryOp ) {
+			switch( ((UnaryOp)current).getOp() ) {
+				case ABS:   
+				case ROUND:
+				case CEIL:
+				case FLOOR:
+				case SIGN:
+				case SELP:    costs = 1; break; 
+				case SPROP:
+				case SQRT:    costs = 2; break;
+				case EXP:     costs = 18; break;
+				case SIGMOID: costs = 21; break;
+				case LOG:    
+				case LOG_NZ:  costs = 32; break;
+				case NCOL:
+				case NROW:
+				case PRINT:
+				case CAST_AS_BOOLEAN:
+				case CAST_AS_DOUBLE:
+				case CAST_AS_INT:
+				case CAST_AS_MATRIX:
+				case CAST_AS_SCALAR: costs = 1; break;
+				case SIN:     costs = 18; break;
+				case COS:     costs = 22; break;
+				case TAN:     costs = 42; break;
+				case ASIN:    costs = 93; break;
+				case ACOS:    costs = 103; break;
+				case ATAN:    costs = 40; break;
+				case CUMSUM:
+				case CUMMIN:
+				case CUMMAX:
+				case CUMPROD: costs = 1; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((UnaryOp)current).getOp());
+			}
+		}
+		else if( current instanceof BinaryOp ) {
+			switch( ((BinaryOp)current).getOp() ) {
+				case MULT: 
+				case PLUS:
+				case MINUS:
+				case MIN:
+				case MAX: 
+				case AND:
+				case OR:
+				case EQUAL:
+				case NOTEQUAL:
+				case LESS:
+				case LESSEQUAL:
+				case GREATER:
+				case GREATEREQUAL: 
+				case CBIND:
+				case RBIND:   costs = 1; break;
+				case INTDIV:  costs = 6; break;
+				case MODULUS: costs = 8; break;
+				case DIV:     costs = 22; break;
+				case LOG:
+				case LOG_NZ:  costs = 32; break;
+				case POW:     costs = (HopRewriteUtils.isLiteralOfValue(
+						current.getInput().get(1), 2) ? 1 : 16); break;
+				case MINUS_NZ:
+				case MINUS1_MULT: costs = 2; break;
+				case CENTRALMOMENT:
+					//binary cm(X, order): the order is at input 1 (default: cm2)
+					int type = (int) (current.getInput().get(1) instanceof LiteralOp ? 
+						HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+					switch( type ) {
+						case 0: costs = 1; break; //count
+						case 1: costs = 8; break; //mean
+						case 2: costs = 16; break; //cm2
+						case 3: costs = 31; break; //cm3
+						case 4: costs = 51; break; //cm4
+						case 5: costs = 16; break; //variance
+					}
+					break;
+				case COVARIANCE: costs = 23; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((BinaryOp)current).getOp());
+			}
+		}
+		else if( current instanceof TernaryOp ) {
+			switch( ((TernaryOp)current).getOp() ) {
+				case PLUS_MULT: 
+				case MINUS_MULT: costs = 2; break;
+				case CTABLE:     costs = 3; break;
+				case CENTRALMOMENT:
+					//weighted cm(X, W, order): input 1 holds the weights, so the
+					//order has to be probed at input 2 (default: cm2)
+					int type = (int) (current.getInput().get(2) instanceof LiteralOp ? 
+						HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(2)) : 2);
+					switch( type ) {
+						case 0: costs = 2; break; //count
+						case 1: costs = 9; break; //mean
+						case 2: costs = 17; break; //cm2
+						case 3: costs = 32; break; //cm3
+						case 4: costs = 52; break; //cm4
+						case 5: costs = 17; break; //variance
+					}
+					break;
+				case COVARIANCE: costs = 23; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((TernaryOp)current).getOp());
+			}
+		}
+		else if( current instanceof ParameterizedBuiltinOp ) {
+			costs = 1;
+		}
+		else if( current instanceof IndexingOp ) {
+			costs = 1;
+		}
+		else if( current instanceof ReorgOp ) {
+			costs = 1;
+		}
+		else if( current instanceof AggBinaryOp ) {
+			//outer product template
+			if( HopRewriteUtils.isOuterProductLikeMM(current) )
+				costs = 2 * current.getInput().get(0).getDim2();
+			//row template w/ matrix-vector or matrix-matrix
+			else
+				costs = 2 * current.getDim2();
+		}
+		else if( current instanceof AggUnaryOp) {
+			switch(((AggUnaryOp)current).getOp()) {
+				case SUM:    costs = 4; break; 
+				case SUM_SQ: costs = 5; break;
+				case MIN:
+				case MAX:    costs = 1; break;
+				default:
+					LOG.warn("Cost model not "
+						+ "implemented yet for: "+((AggUnaryOp)current).getOp());
+			}
+		}
+		
+		computeCosts.put(current.getHopID(), costs);
+	}
+	
+	/**
+	 * Indicates if the given memo table entry of the given hop is NOT
+	 * classified as referring to a materialization point of the current
+	 * plan, as determined by InterestingPoint.isMatPoint.
+	 */
+	private static boolean hasNoRefToMatPoint(long hopID, 
+			MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
+		boolean refersToMatPoint = InterestingPoint.isMatPoint(M, hopID, me, plan);
+		return refersToMatPoint == false;
+	}
+	
+	/**
+	 * Indicates if the input at the given index is a transpose that is fused
+	 * implicitly into a row-template matrix multiplication, i.e., the first
+	 * input of a matrix multiply is a transpose operation.
+	 */
+	private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) {
+		//only the left input of row-template matrix multiplications qualifies
+		if( type != TemplateType.ROW || index != 0 )
+			return false;
+		if( !HopRewriteUtils.isMatrixMultiply(hop) )
+			return false;
+		Hop input = hop.getInput().get(index);
+		return HopRewriteUtils.isTransposeOperation(input);
+	}
+	
+	/**
+	 * Cost vector of a single (fused) operator: a unique ID, the output size,
+	 * the accumulated compute costs, and the sizes of its distinct inputs
+	 * keyed by hop ID.
+	 */
+	private static class CostVector {
+		public final long ID;
+		public final double outSize; 
+		public double computeCosts = 0;
+		public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>();
+		
+		public CostVector(double outputSize) {
+			ID = COST_ID.getNextID();
+			outSize = outputSize;
+		}
+		/** Records the size of an input; repeated inputs overwrite (no double counting). */
+		public void addInputSize(long hopID, double inputSize) {
+			inSizes.put(hopID, inputSize);
+		}
+		/** Returns the sum of all distinct input sizes (0 if no inputs). */
+		public double getSumInputSizes() {
+			return inSizes.values().stream()
+				.mapToDouble(Double::doubleValue).sum();
+		}
+		/** Returns the maximum input size (0 if no inputs). */
+		public double getMaxInputSize() {
+			return inSizes.values().stream()
+				.mapToDouble(Double::doubleValue).max().orElse(0);
+		}
+		/**
+		 * Returns the hop ID of the largest input, or -1 if there are no inputs.
+		 * Inputs of size zero are eligible as well, so that any non-empty vector
+		 * yields a valid hop ID (callers look up and dereference this hop).
+		 * On ties, the first entry in map iteration order wins.
+		 */
+		public long getMaxInputSizeHopID() {
+			long id = -1;
+			double max = Double.NEGATIVE_INFINITY;
+			for( Entry<Long,Double> e : inSizes.entrySet() )
+				if( max < e.getValue() ) {
+					id = e.getKey();
+					max = e.getValue();
+				}
+			return id;
+		}
+		@Override
+		public String toString() {
+			return "["+outSize+", "+computeCosts+", {"
+				+Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", "
+				+Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]";
+		}
+	}
+	
+	/**
+	 * Holder (with final fields) of precomputed costs: the per-hop compute
+	 * costs plus scalar compute, read, and write cost values.
+	 */
+	private static class StaticCosts {
+		public final HashMap<Long, Double> _computeCosts;
+		public final double _compute;
+		public final double _read;
+		public final double _write;
+		
+		public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost) {
+			this._computeCosts = allComputeCosts;
+			this._write = writeCost;
+			this._read = readCost;
+			this._compute = computeCost;
+		}
+	}
+	
+	/**
+	 * Group of aggregate hops (keyed by hop ID) that are candidates for a
+	 * common multi-aggregate, along with the IDs of aggregates they consume
+	 * (input aggregates) and the IDs of their fused inputs.
+	 */
+	private static class AggregateInfo {
+		public final HashMap<Long,Hop> _aggregates = new HashMap<Long, Hop>();
+		public final HashSet<Long> _inputAggs = new HashSet<Long>();
+		public final HashSet<Long> _fusedInputs = new HashSet<Long>();
+		
+		public AggregateInfo(Hop aggregate) {
+			_aggregates.put(aggregate.getHopID(), aggregate);
+		}
+		public void addInputAggregate(long hopID) {
+			_inputAggs.add(hopID);
+		}
+		public void addFusedInput(long hopID) {
+			_fusedInputs.add(hopID);
+		}
+		/**
+		 * Checks if this group can be merged with the given group: at most
+		 * three aggregates in total, no group consumes an aggregate of the
+		 * other, at least one shared fused input, and equally sized main
+		 * inputs (right input of matrix multiplies, else the first input).
+		 */
+		public boolean isMergable(AggregateInfo that) {
+			//size constraints
+			int thisSize = _aggregates.size();
+			if( thisSize >= 3 || thisSize + that._aggregates.size() > 3 )
+				return false;
+			//independence in both directions
+			for( Long hopID : that._aggregates.keySet() )
+				if( _inputAggs.contains(hopID) )
+					return false;
+			for( Long hopID : _aggregates.keySet() )
+				if( that._inputAggs.contains(hopID) )
+					return false;
+			//partial shared reads
+			if( !CollectionUtils.containsAny(_fusedInputs, that._fusedInputs) )
+				return false;
+			//consistent sizes (result correctness)
+			Hop agg1 = _aggregates.values().iterator().next();
+			Hop agg2 = that._aggregates.values().iterator().next();
+			int pos1 = HopRewriteUtils.isMatrixMultiply(agg1) ? 1 : 0;
+			int pos2 = HopRewriteUtils.isMatrixMultiply(agg2) ? 1 : 0;
+			return HopRewriteUtils.isEqualSize(
+				agg1.getInput().get(pos1), agg2.getInput().get(pos2));
+		}
+		/** Merges the given group into this group and returns this group. */
+		public AggregateInfo merge(AggregateInfo that) {
+			_aggregates.putAll(that._aggregates);
+			_inputAggs.addAll(that._inputAggs);
+			_fusedInputs.addAll(that._fusedInputs);
+			return this;
+		}
+		@Override
+		public String toString() {
+			String aggs = Arrays.toString(_aggregates.keySet().toArray(new Long[0]));
+			String inAggs = Arrays.toString(_inputAggs.toArray(new Long[0]));
+			String fused = Arrays.toString(_fusedInputs.toArray(new Long[0]));
+			return "["+aggs+": "+"{"+inAggs+"},"+"{"+fused+"}]";
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
new file mode 100644
index 0000000..759a903
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Map.Entry;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+
+/**
+ * This plan selection heuristic aims for fusion without any redundant 
+ * computation, which, however, potentially leads to more materialized 
+ * intermediates than the fuse all heuristic.
+ * <p>
+ * NOTE: This heuristic is essentially the same as FuseAll, except that 
+ * any plans that refer to a hop with multiple consumers are removed in 
+ * a pre-processing step.
+ * 
+ */
+public class PlanSelectionFuseNoRedundancy extends PlanSelection
+{	
+	@Override
+	public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
+		//pruning and collection pass
+		for( Hop hop : roots )
+			rSelectPlans(memo, hop, null);
+		
+		//take all distinct best plans
+		for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
+			memo.setDistinct(e.getKey(), e.getValue());
+	}
+	
+	private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) 
+	{	
+		if( isVisited(current.getHopID(), currentType) )
+			return;
+		
+		//step 0: remove plans that refer to a common partial plan
+		if( memo.contains(current.getHopID()) ) {
+			HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+			List<MemoTableEntry> hopP = memo.get(current.getHopID());
+			for( MemoTableEntry e1 : hopP )
+				for( int i=0; i<3; i++ )
+					if( e1.isPlanRef(i) && current.getInput().get(i).getParent().size()>1 )
+						rmSet.add(e1); //remove references to hops w/ multiple consumers
+			memo.remove(current, rmSet);
+		}
+		
+		//step 1: prune subsumed plans of same type
+		if( memo.contains(current.getHopID()) ) {
+			HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+			List<MemoTableEntry> hopP = memo.get(current.getHopID());
+			for( MemoTableEntry e1 : hopP )
+				for( MemoTableEntry e2 : hopP )
+					if( e1 != e2 && e1.subsumes(e2) )
+						rmSet.add(e2);
+			memo.remove(current, rmSet);
+		}
+		
+		//step 2: select plan for current path
+		MemoTableEntry best = null;
+		if( memo.contains(current.getHopID()) ) {
+			if( currentType == null ) {
+				best = memo.get(current.getHopID()).stream()
+					.filter(p -> isValid(p, current))
+					.min(new BasicPlanComparator()).orElse(null);
+			}
+			else {
+				best = memo.get(current.getHopID()).stream()
+					.filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+					.min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+					.orElse(null);
+			}
+			addBestPlan(current.getHopID(), best);
+		}
+		
+		//step 3: recursively process children
+		for( int i=0; i< current.getInput().size(); i++ ) {
+			TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
+			rSelectPlans(memo, current.getInput().get(i), pref);
+		}
+		
+		setVisited(current.getHopID(), currentType);
+	}	
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
new file mode 100644
index 0000000..de1ed92
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.hops.codegen.opt.PlanSelection.VisitMarkCost;
+
+/**
+ * Reachability graph over the materialization points (interesting points)
+ * of a plan partition. Each node represents a materialization point, plus a
+ * synthetic root node; a node links to the materialization points reachable
+ * from it without passing another materialization point. The graph is used
+ * to find cut sets, i.e., groups of points that split the remaining points
+ * into two independent subproblems, and to linearize the search space such
+ * that the points of (better-scored) cut sets come first.
+ */
+public class ReachabilityGraph 
+{
+	private HashMap<Pair<Long,Long>,NodeLink> _matPoints = null;
+	private NodeLink _root = null;
+	
+	private InterestingPoint[] _searchSpace;
+	private CutSet[] _cutSets;
+	
+	/**
+	 * Builds the reachability graph of the given partition, determines and
+	 * scores candidate cut sets, and linearizes the search space.
+	 * 
+	 * @param part plan partition
+	 * @param memo memo table (used to resolve the hops of partition roots)
+	 */
+	public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
+		//create repository of materialization points
+		_matPoints = new HashMap<>();
+		for( InterestingPoint p : part.getMatPointsExt() )
+			_matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p));
+		
+		//create reachability graph
+		_root = new NodeLink(null);
+		HashSet<VisitMarkCost> visited = new HashSet<>();
+		for( Long hopID : part.getRoots() ) {
+			Hop rootHop = memo.getHopRefs().get(hopID);
+			addInputNodeLinks(rootHop, _root, part, memo, visited);
+		}
+		
+		//create candidate cutsets (sorted, so nodes with equal inputs are adjacent)
+		List<NodeLink> tmpCS = _matPoints.values().stream()
+			.filter(p -> p._inputs.size() > 0 && p._p != null)
+			.sorted().collect(Collectors.toList());
+		
+		//short-cut for partitions without cutsets
+		if( tmpCS.isEmpty() ) {
+			_cutSets = new CutSet[0];
+			_searchSpace = part.getMatPointsExt();
+			return;
+		}
+		
+		//create composite cutsets (groups of nodes with equal inputs)
+		ArrayList<ArrayList<NodeLink>> candCS = new ArrayList<>();
+		ArrayList<NodeLink> current = new ArrayList<>();
+		for( NodeLink node : tmpCS ) {
+			if( current.isEmpty() )
+				current.add(node);
+			else if( current.get(0).equalInputs(node) )
+				current.add(node);
+			else {
+				candCS.add(current);
+				current = new ArrayList<>();
+				current.add(node);
+			}
+		}
+		if( !current.isEmpty() )
+			candCS.add(current);
+		
+		//evaluate cutsets (single, and duplicate pairs)
+		ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
+		ArrayList<Pair<CutSet,Double>> cutSets = evaluateCutSets(candCS, remain);
+		if( !remain.isEmpty() && remain.size() < 5 ) {
+			//second chance: for pairs for remaining candidates
+			ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>();
+			for( int i=0; i<remain.size()-1; i++)
+				for( int j=i+1; j<remain.size(); j++) {
+					ArrayList<NodeLink> tmp = new ArrayList<>();
+					tmp.addAll(remain.get(i));
+					tmp.addAll(remain.get(j));
+					candCS2.add(tmp);
+				}
+			ArrayList<Pair<CutSet,Double>> cutSets2 = evaluateCutSets(candCS2, remain);
+			//ensure constructed cutsets are disjoint
+			HashSet<InterestingPoint> testDisjoint = new HashSet<>();
+			for( Pair<CutSet,Double> cs : cutSets2 ) {
+				if( !CollectionUtils.containsAny(testDisjoint, Arrays.asList(cs.getLeft().cut)) ) {
+					cutSets.add(cs);
+					CollectionUtils.addAll(testDisjoint, cs.getLeft().cut);
+				}
+			}
+		}
+		
+		//sort and linearize search space according to scores
+		_cutSets = cutSets.stream()
+				.sorted(Comparator.comparing(p -> p.getRight()))
+				.map(p -> p.getLeft()).toArray(CutSet[]::new);
+	
+		//probe maps each point to its position in the linearized search space;
+		//positions are taken as the current size right before appending, which
+		//keeps them consistent with the cut positions assigned via updatePos
+		//(note: a put(p, probe.size()-1) would evaluate size() before the put
+		//and hence shift all positions by one)
+		HashMap<InterestingPoint, Integer> probe = new HashMap<>();
+		ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
+		for( CutSet cs : _cutSets ) {
+			cs.updatePos(lsearchSpace.size());
+			cs.updatePartitions(probe);
+			for( InterestingPoint p: cs.cut ) {
+				probe.put(p, lsearchSpace.size());
+				lsearchSpace.add(p);
+			}
+		}
+		for( InterestingPoint p : part.getMatPointsExt() )
+			if( !probe.containsKey(p) ) {
+				probe.put(p, lsearchSpace.size());
+				lsearchSpace.add(p);
+			}
+		_searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);
+		
+		//materialize partition indices
+		for( CutSet cs : _cutSets ) {
+			cs.updatePartitionIndexes(probe);
+			cs.finalizePartition();
+		}
+		
+		//final sanity check of interesting points
+		if( _searchSpace.length != part.getMatPointsExt().length )
+			throw new RuntimeException("Corrupt linearized search space: " +
+				_searchSpace.length+" vs "+part.getMatPointsExt().length);
+	}
+	
+	/** Returns the linearized search space (cut set points first). */
+	public InterestingPoint[] getSortedSearchSpace() {
+		return _searchSpace;
+	}
+
+	/** Indicates if the given plan materializes all points of any cut set. */
+	public boolean isCutSet(boolean[] plan) {
+		for( CutSet cs : _cutSets )
+			if( isCutSet(cs, plan) )
+				return true;
+		return false;
+	}
+	
+	/** Indicates if the given plan materializes all points of the given cut set. */
+	public boolean isCutSet(CutSet cs, boolean[] plan) {
+		boolean ret = true;
+		for(int i=0; i<cs.posCut.length && ret; i++)
+			ret &= plan[cs.posCut[i]];
+		return ret;
+	}
+	
+	/** Returns the first cut set fully materialized by the given plan. */
+	public CutSet getCutSet(boolean[] plan) {
+		for( CutSet cs : _cutSets )
+			if( isCutSet(cs, plan) )
+				return cs;
+		throw new RuntimeException("No valid cut set found.");
+	}
+
+	/**
+	 * Returns the number of plans that can be skipped for the given plan,
+	 * i.e., 2^(number of positions after the last cut position).
+	 */
+	public long getNumSkipPlans(boolean[] plan) {
+		for( CutSet cs : _cutSets )
+			if( isCutSet(cs, plan) ) {
+				int pos = cs.posCut[cs.posCut.length-1];				
+				return (long) Math.pow(2, plan.length-pos-1);
+			}
+		throw new RuntimeException("Failed to compute "
+			+ "number of skip plans for plan without cutset.");
+	}
+
+	/** Returns the left and right subproblems of the cut set of the given plan. */
+	public SubProblem[] getSubproblems(boolean[] plan) {
+		CutSet cs = getCutSet(plan);
+		return new SubProblem[] {
+				new SubProblem(cs.cut.length, cs.posLeft, cs.left), 
+				new SubProblem(cs.cut.length, cs.posRight, cs.right)};
+	}
+	
+	@Override
+	public String toString() {
+		return "ReachabilityGraph("+_matPoints.size()+"):\n"
+			+ _root.explain(new HashSet<>());
+	}
+	
+	//links the given parent node to all materialization points reachable from
+	//current without passing another materialization point (recursively)
+	private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part, 
+		CPlanMemoTable memo, HashSet<VisitMarkCost> visited) 
+	{
+		if( visited.contains(new VisitMarkCost(current.getHopID(), parent._ID)) )
+			return;
+		
+		//process children
+		for( Hop in : current.getInput() ) {
+			if( InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID()) ) {
+				NodeLink tmp = _matPoints.get(Pair.of(current.getHopID(), in.getHopID()));
+				parent.addInput(tmp);
+				addInputNodeLinks(in, tmp, part, memo, visited);
+			}
+			else
+				addInputNodeLinks(in, parent, part, memo, visited);
+		}
+		
+		visited.add(new VisitMarkCost(current.getHopID(), parent._ID));
+	}
+	
+	//collects all nodes reachable from current without entering probe nodes
+	private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) {
+		for( NodeLink c : current._inputs ) 
+			if( !probe.contains(c) ) {
+				rCollectInputs(c, probe, inputs);
+				inputs.add(c);
+			}
+	}
+	
+	//creates and scores (smaller is better) all candidates that split the graph
+	//into two disjoint, non-empty parts; other candidates are added to remain
+	private ArrayList<Pair<CutSet,Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS, ArrayList<ArrayList<NodeLink>> remain) {
+		ArrayList<Pair<CutSet,Double>> cutSets = new ArrayList<>();
+		
+		for( ArrayList<NodeLink> cand : candCS ) {
+			HashSet<NodeLink> probe = new HashSet<>(cand);
+			
+			//determine subproblems for cutset candidates
+			HashSet<NodeLink> part1 = new HashSet<>();
+			rCollectInputs(_root, probe, part1);
+			HashSet<NodeLink> part2 = new HashSet<>();
+			for( NodeLink rNode : cand )
+				rCollectInputs(rNode, probe, part2);
+			
+			//select, score and create cutsets
+			if( !CollectionUtils.containsAny(part1, part2) 
+				&& !part1.isEmpty() && !part2.isEmpty()) {
+				//score cutsets (smaller is better)
+				double base = Math.pow(2, _matPoints.size());
+				double numComb = Math.pow(2, cand.size());
+				double score = (numComb-1)/numComb * base
+					+ 1/numComb * Math.pow(2, part1.size())
+					+ 1/numComb * Math.pow(2, part2.size());
+				
+				//construct cutset
+				cutSets.add(Pair.of(new CutSet(
+					cand.stream().map(p->p._p).toArray(InterestingPoint[]::new), 
+					part1.stream().map(p->p._p).toArray(InterestingPoint[]::new), 
+					part2.stream().map(p->p._p).toArray(InterestingPoint[]::new)), score));
+			}
+			else {
+				remain.add(cand);
+			}
+		}
+		
+		return cutSets;
+	}
+	
+	/** Independent subproblem: an offset and the free positions/points to enumerate. */
+	public static class SubProblem {
+		public int offset;
+		public int[] freePos;
+		public InterestingPoint[] freeMat;
+		
+		public SubProblem(int off, int[] pos, InterestingPoint[] mat) {
+			offset = off;
+			freePos = pos;
+			freeMat = mat;
+		}
+	}
+	
+	/** Cut set with its cut points and the points of the left and right partitions. */
+	public static class CutSet {
+		public InterestingPoint[] cut;
+		public InterestingPoint[] left;
+		public InterestingPoint[] right;
+		public int[] posCut;
+		public int[] posLeft;
+		public int[] posRight;
+		
+		public CutSet(InterestingPoint[] cutPoints, 
+				InterestingPoint[] l, InterestingPoint[] r) {
+			cut = cutPoints;
+			left = l;
+			right = r;
+		}
+		
+		//assigns consecutive search space positions starting at index
+		public void updatePos(int index) {
+			posCut = new int[cut.length];
+			for(int i=0; i<posCut.length; i++)
+				posCut[i] = index + i;
+		}
+		
+		//removes points already assigned to previous cut sets
+		public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) {
+			left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p))
+				.toArray(InterestingPoint[]::new);
+			right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p))
+				.toArray(InterestingPoint[]::new);
+		}
+		
+		//resolves the search space positions of the left and right points
+		public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) {
+			posLeft = new int[left.length];
+			for(int i=0; i<left.length; i++)
+				posLeft[i] = probe.get(left[i]);
+			posRight = new int[right.length];
+			for(int i=0; i<right.length; i++)
+				posRight[i] = probe.get(right[i]);
+		}
+		
+		//prepends the cut points to both partitions
+		public void finalizePartition() {
+			left = (InterestingPoint[]) ArrayUtils.addAll(cut, left);
+			right = (InterestingPoint[]) ArrayUtils.addAll(cut, right);
+		}
+		
+		@Override
+		public String toString() {
+			return "Cut : "+Arrays.toString(cut);
+		}
+	}
+	
+	/**
+	 * Node of the reachability graph. Equality and hashing are identity-based
+	 * because nodes are collected in hash sets; ordering (compareTo) and
+	 * grouping (equalInputs) consider the input links instead.
+	 */
+	private static class NodeLink implements Comparable<NodeLink>
+	{
+		private static final IDSequence _seqID = new IDSequence();
+		
+		private ArrayList<NodeLink> _inputs = new ArrayList<>();
+		private long _ID;
+		private InterestingPoint _p;
+		
+		public NodeLink(InterestingPoint p) {
+			_ID = _seqID.getNextID();
+			_p = p;
+		} 
+		
+		public void addInput(NodeLink in) {
+			_inputs.add(in);
+		}
+		
+		/** Indicates if both nodes have the same input links (by node ID, in order). */
+		public boolean equalInputs(NodeLink that) {
+			boolean ret = (_inputs.size() == that._inputs.size());
+			for( int i=0; i<_inputs.size() && ret; i++ )
+				ret &= (_inputs.get(i)._ID == that._inputs.get(i)._ID);
+			return ret;
+		}
+		
+		@Override
+		public int compareTo(NodeLink that) {
+			if( _inputs.size() > that._inputs.size() )
+				return -1;
+			else if( _inputs.size() < that._inputs.size() )
+				return 1;
+			for( int i=0; i<_inputs.size(); i++ ) {
+				int comp = Long.compare(_inputs.get(i)._ID, 
+					that._inputs.get(i)._ID);
+				if( comp != 0 )
+					return comp;
+			}
+			return 0;
+		}
+		
+		@Override
+		public String toString() {
+			StringBuilder inputs = new StringBuilder();
+			for(NodeLink in : _inputs) {
+				if( inputs.length() > 0 )
+					inputs.append(",");
+				inputs.append(in._ID);
+			}
+			return _ID+" ("+inputs.toString()+") "+((_p!=null)?_p:"null");
+		}
+		
+		private String explain(HashSet<Long> visited) {
+			if( visited.contains(_ID) )
+				return "";
+			//add children
+			StringBuilder sb = new StringBuilder();
+			StringBuilder inputs = new StringBuilder();
+			for(NodeLink in : _inputs) {
+				String tmp = in.explain(visited);
+				if( !tmp.isEmpty() )
+					sb.append(tmp + "\n");
+				if( inputs.length() > 0 )
+					inputs.append(",");
+				inputs.append(in._ID);
+			}
+			//add node itself
+			sb.append(_ID+" ("+inputs+") "+((_p!=null)?_p:"null"));
+			visited.add(_ID);
+			
+			return sb.toString();
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index edbcdf9..4078060 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.template;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -36,6 +37,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.codegen.SpoofCompiler;
+import org.apache.sysml.hops.codegen.opt.InterestingPoint;
+import org.apache.sysml.hops.codegen.opt.PlanSelection;
 import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
@@ -53,6 +56,18 @@ public class CPlanMemoTable
 		_plansBlacklist = new HashSet<Long>();
 	}
 	
+	/** Returns the internal map (not a copy) of hop IDs to their memo table entries. */
+	public HashMap<Long, List<MemoTableEntry>> getPlans() {
+		return this._plans;
+	}
+	
+	/** Returns the internal set (not a copy) of blacklisted hop IDs. */
+	public HashSet<Long> getPlansBlacklisted() {
+		return this._plansBlacklist;
+	}
+	
+	/** Returns the internal map (not a copy) of hop IDs to registered hops. */
+	public HashMap<Long, Hop> getHopRefs() {
+		return this._hopRefs;
+	}
+	
 	public void addHop(Hop hop) {
+		//register the hop under its hop ID (later resolvable via _hopRefs)
 		_hopRefs.put(hop.getHopID(), hop);
 	}
@@ -78,6 +93,14 @@ public class CPlanMemoTable
 			.anyMatch(p -> (!checkClose||!p.closed) && probe.contains(p.type));
 	}
 	
+	/**
+	 * Indicates if the given hop has a memo table entry whose template type
+	 * is not contained in the given types, optionally considering only entries
+	 * with plan references (checkChildRefs) and only non-CELL entries
+	 * (excludeCell). Returns false if the hop has no memo table entries.
+	 */
+	public boolean containsNotIn(long hopID, Collection<TemplateType> types, 
+		boolean checkChildRefs, boolean excludeCell) {
+		if( !contains(hopID) )
+			return false;
+		for( MemoTableEntry me : get(hopID) ) {
+			boolean refsOk = !checkChildRefs || me.hasPlanRef();
+			boolean cellOk = !excludeCell || me.type != TemplateType.CELL;
+			if( refsOk && cellOk && !types.contains(me.type) )
+				return true;
+		}
+		return false;
+	}
+	
 	public int countEntries(long hopID) {
+		//number of memo entries of the given hop; NOTE(review): get(hopID) may
+		//return null for unknown hops (cf. null check in getBest), so callers
+		//presumably check contains(hopID) first -- TODO confirm
 		return get(hopID).size();
 	}
@@ -85,7 +108,7 @@ public class CPlanMemoTable
 	public int countEntries(long hopID, TemplateType type) {
+		//number of memo entries of the given hop with the given template type
 		return (int) get(hopID).stream()
 			.filter(p -> p.type==type).count();
-	} 
+	}
 	
 	public boolean containsTopLevel(long hopID) {
 		return !_plansBlacklist.contains(hopID)
@@ -133,7 +156,7 @@ public class CPlanMemoTable
 			.distinct().collect(Collectors.toList()));
 	}
 
-	public void pruneRedundant(long hopID) {
+	public void pruneRedundant(long hopID, boolean pruneDominated, InterestingPoint[] matPoints) {
 		if( !contains(hopID) )
 			return;
 		
@@ -146,7 +169,7 @@ public class CPlanMemoTable
 		//prune dominated plans (e.g., opened plan subsumed by fused plan 
 		//if single consumer of input; however this only applies to fusion
 		//heuristic that only consider materialization points)
-		if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+		if( pruneDominated ) {
 			HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>();
 			List<MemoTableEntry> list = _plans.get(hopID);
 			Hop hop = _hopRefs.get(hopID);
@@ -155,9 +178,12 @@ public class CPlanMemoTable
 					if( e1 != e2 && e1.subsumes(e2) ) {
 						//check that childs don't have multiple consumers
 						boolean rmSafe = true; 
-						for( int i=0; i<=2; i++ )
+						for( int i=0; i<=2; i++ ) {
 							rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
-								hop.getInput().get(i).getParent().size()==1 : true;
+								(matPoints!=null && !InterestingPoint.isMatPoint(
+									matPoints, hopID, e1.input(i)))
+								|| hop.getInput().get(i).getParent().size()==1 : true;
+						}
 						if( rmSafe )
 							rmList.add(e2);
 					}
@@ -194,12 +220,14 @@ public class CPlanMemoTable
 		//prune dominated plans (e.g., plan referenced by other plan and this
 		//other plan is single consumer) by marking it as blacklisted because
 		//the chain of entries is still required for cplan construction
-		for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
-			for( MemoTableEntry me : e.getValue() ) {
-				for( int i=0; i<=2; i++ )
-					if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
-						_plansBlacklist.add(me.input(i));
-			}
+		if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+			for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
+				for( MemoTableEntry me : e.getValue() ) {
+					for( int i=0; i<=2; i++ )
+						if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
+							_plansBlacklist.add(me.input(i));
+				}
+		}
 		
 		//core plan selection
 		PlanSelection selector = SpoofCompiler.createPlanSelector();
@@ -232,6 +260,16 @@ public class CPlanMemoTable
 			.distinct().collect(Collectors.toList());
 	}
 	
+	/**
+	 * Returns the distinct template types of all memo table entries of the
+	 * given hop that hold a plan reference at input position refAt.
+	 * 
+	 * @param hopID hop id to look up
+	 * @param refAt input position that must hold a plan reference
+	 * @return distinct template types in encounter order, or an empty
+	 *         list if the hop has no memo table entries
+	 */
+	public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt) {
+		List<TemplateType> ret = Collections.emptyList();
+		if( contains(hopID) ) {
+			ret = _plans.get(hopID).stream()
+				.filter(me -> me.isPlanRef(refAt)) //keep entries referencing refAt
+				.map(me -> me.type)
+				.distinct()
+				.collect(Collectors.toList());
+		}
+		return ret;
+	}
+	
 	public MemoTableEntry getBest(long hopID) {
 		List<MemoTableEntry> tmp = get(hopID);
 		if( tmp == null || tmp.isEmpty() )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
deleted file mode 100644
index f8a12fd..0000000
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-import org.apache.sysml.hops.rewrite.HopRewriteUtils;
-import org.apache.sysml.runtime.util.UtilFunctions;
-
-/**
- * Abstract base class for codegen plan selection policies. It keeps the
- * selected (best) memo table entries per hop as well as a set of visited
- * (hop id, template type) pairs for use during plan traversals.
- */
-public abstract class PlanSelection 
-{
-	//selected memo table entries, keyed by hop id
-	private final HashMap<Long, List<MemoTableEntry>> _bestPlans = 
-			new HashMap<Long, List<MemoTableEntry>>();
-	//visited (hop id, template type) pairs; the type may be null
-	private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
-	
-	/**
-	 * Given a HOP DAG G, and a set of partial fusion plans P, find the set of optimal, 
-	 * non-conflicting fusion plans P' that, applied to G, minimize costs C with
-	 * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a set of 
-	 * constraints such as memory budgets and block size restrictions per fused operator.
-	 * 
-	 * @param memo partial fusion plans P
-	 * @param roots entry points of HOP DAG G
-	 */
-	public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots);	
-	
-	/**
-	 * Determines if the given partial fusion plan is valid. Row, cell, and
-	 * multi-aggregate plans are always valid; outer-product plans only if
-	 * they are closed or the hop is a binary matrix-matrix operation.
-	 * 
-	 * @param me memo table entry
-	 * @param hop current hop
-	 * @return true if entry is valid as top-level plan
-	 */
-	protected static boolean isValid(MemoTableEntry me, Hop hop) {
-		return (me.type == TemplateType.OuterProdTpl 
-				&& (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)))
-			|| (me.type == TemplateType.RowTpl)	
-			|| (me.type == TemplateType.CellTpl)
-			|| (me.type == TemplateType.MultiAggTpl);
-	}
-	
-	/**
-	 * Appends the given entry to the selected plans of the given hop,
-	 * creating the per-hop list on first use; null entries are ignored.
-	 * 
-	 * @param hopID hop id
-	 * @param me selected memo table entry, may be null
-	 */
-	protected void addBestPlan(long hopID, MemoTableEntry me) {
-		if( me == null ) return;
-		if( !_bestPlans.containsKey(hopID) )
-			_bestPlans.put(hopID, new ArrayList<MemoTableEntry>());
-		_bestPlans.get(hopID).add(me);
-	}
-	
-	/**
-	 * Returns the internal map of selected plans (not a copy).
-	 * 
-	 * @return map from hop id to selected memo table entries
-	 */
-	protected HashMap<Long, List<MemoTableEntry>> getBestPlans() {
-		return _bestPlans;
-	}
-	
-	//checks if the (hopID, type) pair was previously marked via setVisited
-	protected boolean isVisited(long hopID, TemplateType type) {
-		return _visited.contains(new VisitMark(hopID, type));
-	}
-	
-	//marks the (hopID, type) pair as visited
-	protected void setVisited(long hopID, TemplateType type) {
-		_visited.add(new VisitMark(hopID, type));
-	}
-	
-	/**
-	 * Basic plan comparator to compare memo table entries with regard to
-	 * a pre-defined template preference order and the number of references.
-	 * Entries are ordered ascending by type rank; for equal types, entries
-	 * with more plan references are ordered first.
-	 */
-	protected static class BasicPlanComparator implements Comparator<MemoTableEntry> {
-		@Override
-		public int compare(MemoTableEntry o1, MemoTableEntry o2) {
-			//for different types, select preferred type
-			if( o1.type != o2.type )
-				return Integer.compare(o1.type.getRank(), o2.type.getRank());
-			
-			//for same type, prefer plan with more refs
-			return Integer.compare(
-				3-o1.countPlanRefs(), 3-o2.countPlanRefs());
-		}
-	}
-	
-	/**
-	 * Hashable (hop id, template type) pair used as key of the visited set;
-	 * the template type may be null.
-	 */
-	private static class VisitMark {
-		private final long _hopID;
-		private final TemplateType _type;
-		
-		public VisitMark(long hopID, TemplateType type) {
-			_hopID = hopID;
-			_type = type;
-		}
-		@Override
-		public int hashCode() {
-			//null type hashes as 0
-			return UtilFunctions.longHashCode(
-				_hopID, (_type!=null)?_type.hashCode():0);
-		}
-		@Override 
-		public boolean equals(Object o) {
-			return (o instanceof VisitMark
-				&& _hopID == ((VisitMark)o)._hopID
-				&& _type == ((VisitMark)o)._type);
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
deleted file mode 100644
index a455302..0000000
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.Map.Entry;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-
-/**
- * This plan selection heuristic aims for maximal fusion, which
- * potentially leads to overlapping fused operators and thus,
- * redundant computation but with a minimal number of materialized
- * intermediate results.
- */
-public class PlanSelectionFuseAll extends PlanSelection
-{	
-	/**
-	 * Prunes subsumed entries and selects best plans via a recursive pass
-	 * from each root, then keeps only the selected entries in the memo table.
-	 */
-	@Override
-	public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
-		//pruning and collection pass
-		for( Hop hop : roots )
-			rSelectPlans(memo, hop, null);
-		
-		//take all distinct best plans
-		for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
-			memo.setDistinct(e.getKey(), e.getValue());
-	}
-	
-	/**
-	 * Recursively prunes subsumed entries and selects the best plan of the
-	 * given hop, then descends into its inputs. Each (hop, type context)
-	 * pair is processed at most once.
-	 * 
-	 * @param memo memo table of partial fusion plans
-	 * @param current current hop
-	 * @param currentType template type of the consumer's selected plan if
-	 *        that plan references this hop, or null otherwise (e.g., roots)
-	 */
-	private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType) 
-	{	
-		if( isVisited(current.getHopID(), currentType) )
-			return;
-		
-		//step 1: prune subsumed plans of same type
-		if( memo.contains(current.getHopID()) ) {
-			HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
-			List<MemoTableEntry> hopP = memo.get(current.getHopID());
-			for( MemoTableEntry e1 : hopP )
-				for( MemoTableEntry e2 : hopP )
-					if( e1 != e2 && e1.subsumes(e2) )
-						rmSet.add(e2);
-			memo.remove(current, rmSet);
-		}
-		
-		//step 2: select plan for current path
-		MemoTableEntry best = null;
-		if( memo.contains(current.getHopID()) ) {
-			if( currentType == null ) {
-				//no consumer context: best valid top-level plan by type rank and refs
-				best = memo.get(current.getHopID()).stream()
-					.filter(p -> isValid(p, current))
-					.min(new BasicPlanComparator()).orElse(null);
-			}
-			else {
-				//consumer context: only entries of the consumer's type or cell entries;
-				//score 7-(4 if same type)-#refs prefers same-type entries, then more
-				//plan refs (type dominates presumably because #refs is at most 3)
-				best = memo.get(current.getHopID()).stream()
-					.filter(p -> p.type==currentType || p.type==TemplateType.CellTpl)
-					.min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
-					.orElse(null);
-			}
-			addBestPlan(current.getHopID(), best);
-		}
-		
-		//step 3: recursively process children
-		//(propagate the selected type only to inputs referenced by the best plan)
-		for( int i=0; i< current.getInput().size(); i++ ) {
-			TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
-			rSelectPlans(memo, current.getInput().get(i), pref);
-		}
-		
-		setVisited(current.getHopID(), currentType);
-	}	
-}