You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/08/05 07:55:49 UTC
[2/3] systemml git commit: [SYSTEMML-1741, 1536,
1296] New cost-based codegen optimizer (V2)
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
new file mode 100644
index 0000000..2fa0de7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -0,0 +1,1100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.IndexingOp;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
+import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
+import org.apache.sysml.hops.UnaryOp;
+import org.apache.sysml.hops.codegen.opt.ReachabilityGraph.SubProblem;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.TemplateOuterProduct;
+import org.apache.sysml.hops.codegen.template.TemplateRow;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+import org.apache.sysml.hops.rewrite.HopRewriteUtils;
+import org.apache.sysml.runtime.codegen.LibSpoofPrimitives;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * This cost-based plan selection algorithm chooses fused operators
+ * based on the DAG structure and resulting overall costs. This includes
+ * holistic decisions on
+ * <ul>
+ * <li>Materialization points per consumer</li>
+ * <li>Sparsity exploitation and operator ordering</li>
+ * <li>Decisions on overlapping template types</li>
+ * <li>Decisions on multi-aggregates with shared reads</li>
+ * <li>Constraints (e.g., memory budgets and block sizes)</li>
+ * </ul>
+ *
+ */
+public class PlanSelectionFuseCostBasedV2 extends PlanSelection
+{
+ private static final Log LOG = LogFactory.getLog(PlanSelectionFuseCostBasedV2.class.getName());
+
+ //common bandwidth characteristics, with a conservative write bandwidth in order
+ //to cover result allocation, write into main memory, and potential evictions
+ private static final double WRITE_BANDWIDTH = 2d*1024*1024*1024; //2GB/s
+ private static final double READ_BANDWIDTH = 32d*1024*1024*1024; //32GB/s
+ private static final double COMPUTE_BANDWIDTH = 2d*1024*1024*1024 //2GFLOPs/core
+ * InfrastructureAnalyzer.getLocalParallelism();
+
+ //sparsity estimate for unknown sparsity to prefer sparse-safe fusion plans
+ private static final double SPARSE_SAFE_SPARSITY_EST = 0.1;
+
+ //optimizer configuration
+ private static final boolean USE_COST_PRUNING = true;
+ private static final boolean USE_STRUCTURAL_PRUNING = true;
+
+ private static final IDSequence COST_ID = new IDSequence();
+ private static final TemplateRow ROW_TPL = new TemplateRow();
+ private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator();
+ private final TypedPlanComparator _typedCompare = new TypedPlanComparator();
+
+ @Override
+ public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
+ {
+ //step 1: analyze connected partitions (nodes, roots, mat points)
+ Collection<PlanPartition> parts = PlanAnalyzer.analyzePlanPartitions(memo, roots, true);
+
+ //step 2: optimize individual plan partitions
+ for( PlanPartition part : parts ) {
+ //create composite templates (within the partition)
+ createAndAddMultiAggPlans(memo, part.getPartition(), part.getRoots());
+
+ //plan enumeration and plan selection
+ selectPlans(memo, part);
+ }
+
+ //step 3: add composite templates (across partitions)
+ createAndAddMultiAggPlans(memo, roots);
+
+ //take all distinct best plans
+ for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
+ memo.setDistinct(e.getKey(), e.getValue());
+ }
+
+ private void selectPlans(CPlanMemoTable memo, PlanPartition part)
+ {
+ //prune row aggregates with pure cellwise operations
+ for( Long hopID : part.getRoots() ) {
+ MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
+ if( me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
+ && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+ List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
+ memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist));
+ if( LOG.isTraceEnabled() ) {
+ LOG.trace("Removed row memo table entries w/o aggregation: "
+ + Arrays.toString(blacklist.toArray(new MemoTableEntry[0])));
+ }
+ }
+ }
+
+ //prune suboptimal outer product plans that are dominated by outer product plans w/ same number of
+ //references but better fusion properties (e.g., for the patterns Y=X*(U%*%t(V)) and sum(Y*(U2%*%t(V2))),
+ //we'd prune sum(X*(U%*%t(V))*Z), Z=U2%*%t(V2) because this would unnecessarily destroy a fusion pattern.
+ for( Long hopID : part.getPartition() ) {
+ if( memo.countEntries(hopID, TemplateType.OUTER) == 2 ) {
+ List<MemoTableEntry> entries = memo.get(hopID, TemplateType.OUTER);
+ MemoTableEntry me1 = entries.get(0);
+ MemoTableEntry me2 = entries.get(1);
+ MemoTableEntry rmEntry = TemplateOuterProduct.dropAlternativePlan(memo, me1, me2);
+ if( rmEntry != null ) {
+ memo.remove(memo.getHopRefs().get(hopID), Collections.singleton(rmEntry));
+ memo.getPlansBlacklisted().remove(rmEntry.input(rmEntry.getPlanRefIndex()));
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Removed dominated outer product memo table entry: " + rmEntry);
+ }
+ }
+ }
+
+ //if no materialization points, use basic fuse-all w/ partition awareness
+ if( part.getMatPointsExt() == null || part.getMatPointsExt().length==0 ) {
+ for( Long hopID : part.getRoots() )
+ rSelectPlansFuseAll(memo,
+ memo.getHopRefs().get(hopID), null, part.getPartition());
+ }
+ else {
+ //obtain hop compute costs per cell once
+ HashMap<Long, Double> computeCosts = new HashMap<Long, Double>();
+ for( Long hopID : part.getRoots() )
+ rGetComputeCosts(memo.getHopRefs().get(hopID), part.getPartition(), computeCosts);
+
+ //prepare pruning helpers and prune memo table w/ determined mat points
+ StaticCosts costs = new StaticCosts(computeCosts, getComputeCost(computeCosts, memo),
+ getReadCost(part, memo), getWriteCost(part.getRoots(), memo));
+ ReachabilityGraph rgraph = USE_STRUCTURAL_PRUNING ? new ReachabilityGraph(part, memo) : null;
+ if( USE_STRUCTURAL_PRUNING ) {
+ part.setMatPointsExt(rgraph.getSortedSearchSpace());
+ for( Long hopID : part.getPartition() )
+ memo.pruneRedundant(hopID, true, part.getMatPointsExt());
+ }
+
+ //enumerate and cost plans, returns optional plan
+ boolean[] bestPlan = enumPlans(memo, part, costs, rgraph,
+ part.getMatPointsExt(), 0, Double.MAX_VALUE);
+
+ //prune memo table wrt best plan and select plans
+ HashSet<Long> visited = new HashSet<Long>();
+ for( Long hopID : part.getRoots() )
+ rPruneSuboptimalPlans(memo, memo.getHopRefs().get(hopID),
+ visited, part, part.getMatPointsExt(), bestPlan);
+ HashSet<Long> visited2 = new HashSet<Long>();
+ for( Long hopID : part.getRoots() )
+ rPruneInvalidPlans(memo, memo.getHopRefs().get(hopID),
+ visited2, part, bestPlan);
+
+ for( Long hopID : part.getRoots() )
+ rSelectPlansFuseAll(memo,
+ memo.getHopRefs().get(hopID), null, part.getPartition());
+ }
+ }
+
+ /**
+ * Core plan enumeration algorithm, invoked recursively for conditionally independent
+ * subproblems. This algorithm fully explores the exponential search space of 2^m,
+ * where m is the number of interesting materialization points. We iterate over
+ * a linearized search space without ever instantiating the search tree. Furthermore,
+ * in order to reduce the enumeration overhead, we apply two high-impact pruning
+ * techniques (1) pruning by evolving lower/upper cost bounds, and (2) pruning by
+ * conditional structural properties (so-called cutsets of interesting points).
+ *
+ * @param memo memoization table of partial fusion plans
+ * @param part connected component (partition) of partial fusion plans with all necessary meta data
+ * @param costs summary of static costs (e.g., partition reads, writes, and compute costs per operator)
+ * @param rgraph reachability graph of interesting materialization points
+ * @param matPoints sorted materialization points (define the search space)
+ * @param off offset for recursive invocation, indicating the fixed plan part
+ * @param bestC currently known best plan costs (used as upper bound)
+ * @return optimal assignment of materialization points
+ */
+ private static boolean[] enumPlans(CPlanMemoTable memo, PlanPartition part, StaticCosts costs,
+ ReachabilityGraph rgraph, InterestingPoint[] matPoints, int off, double bestC)
+ {
+ //scan linearized search space, w/ skips for branch and bound pruning
+ //and structural pruning (where we solve conditionally independent problems)
+ //bestC is monotonically non-increasing and serves as the upper bound
+ long len = (long)Math.pow(2, matPoints.length-off);
+ boolean[] bestPlan = null;
+ int numEvalPlans = 0;
+
+ for( long i=0; i<len; i++ ) {
+ //construct assignment
+ boolean[] plan = createAssignment(matPoints.length-off, off, i);
+ long pskip = 0; //skip after costing
+
+ //skip plans with structural pruning
+ if( USE_STRUCTURAL_PRUNING && (rgraph!=null) && rgraph.isCutSet(plan) ) {
+ //compute skip (which also acts as boundary for subproblems)
+ pskip = rgraph.getNumSkipPlans(plan);
+
+ //start increment rgraph get subproblems
+ SubProblem[] prob = rgraph.getSubproblems(plan);
+
+ //solve subproblems independently and combine into best plan
+ for( int j=0; j<prob.length; j++ ) {
+ boolean[] bestTmp = enumPlans(memo, part,
+ costs, null, prob[j].freeMat, prob[j].offset, bestC);
+ LibSpoofPrimitives.vectWrite(bestTmp, plan, prob[j].freePos);
+ }
+
+ //note: the overall plan costs are evaluated in full, which reused
+ //the default code path; hence we postpone the skip after costing
+ }
+ //skip plans with branch and bound pruning (cost)
+ else if( USE_COST_PRUNING ) {
+ double lbC = Math.max(costs._read, costs._compute) + costs._write
+ + getMaterializationCost(part, matPoints, memo, plan);
+ if( lbC >= bestC ) {
+ long skip = getNumSkipPlans(plan);
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Enum: Skip "+skip+" plans (by cost).");
+ i += skip - 1;
+ continue;
+ }
+ }
+
+ //cost assignment on hops
+ double C = getPlanCost(memo, part, matPoints, plan, costs._computeCosts);
+ numEvalPlans ++;
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Enum: "+Arrays.toString(plan)+" -> "+C);
+
+ //cost comparisons
+ if( bestPlan == null || C < bestC ) {
+ bestC = C;
+ bestPlan = plan;
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Enum: Found new best plan.");
+ }
+
+ //post skipping
+ i += pskip;
+ if( pskip !=0 && LOG.isTraceEnabled() )
+ LOG.trace("Enum: Skip "+pskip+" plans (by structure).");
+ }
+
+ if( DMLScript.STATISTICS )
+ Statistics.incrementCodegenFPlanCompile(numEvalPlans);
+ if( LOG.isTraceEnabled() )
+ LOG.trace("Enum: Optimal plan: "+Arrays.toString(bestPlan));
+
+ //copy best plan w/o fixed offset plan
+ return Arrays.copyOfRange(bestPlan, off, bestPlan.length);
+ }
+
+ private static boolean[] createAssignment(int len, int off, long pos) {
+ boolean[] ret = new boolean[off+len];
+ Arrays.fill(ret, 0, off, true);
+ long tmp = pos;
+ for( int i=0; i<len; i++ ) {
+ ret[off+i] = (tmp >= Math.pow(2, len-i-1));
+ tmp %= Math.pow(2, len-i-1);
+ }
+ return ret;
+ }
+
+ private static long getNumSkipPlans(boolean[] plan) {
+ int pos = ArrayUtils.lastIndexOf(plan, true);
+ return (long) Math.pow(2, plan.length-pos-1);
+ }
+
+ private static double getMaterializationCost(PlanPartition part, InterestingPoint[] M, CPlanMemoTable memo, boolean[] plan) {
+ double costs = 0;
+ //currently active materialization points
+ HashSet<Long> matTargets = new HashSet<>();
+ for( int i=0; i<plan.length; i++ ) {
+ long hopID = M[i].getToHopID();
+ if( plan[i] && !matTargets.contains(hopID) ) {
+ matTargets.add(hopID);
+ Hop hop = memo.getHopRefs().get(hopID);
+ long size = getSize(hop);
+ costs += size * 8 / WRITE_BANDWIDTH +
+ size * 8 / READ_BANDWIDTH;
+ }
+ }
+ //points with non-partition consumers
+ for( Long hopID : part.getExtConsumed() )
+ if( !matTargets.contains(hopID) ) {
+ matTargets.add(hopID);
+ Hop hop = memo.getHopRefs().get(hopID);
+ costs += getSize(hop) * 8 / WRITE_BANDWIDTH;
+ }
+
+ return costs;
+ }
+
+ private static double getReadCost(PlanPartition part, CPlanMemoTable memo) {
+ double costs = 0;
+ //get partition input reads (at least read once)
+ for( Long hopID : part.getInputs() ) {
+ Hop hop = memo.getHopRefs().get(hopID);
+ costs += getSize(hop) * 8 / READ_BANDWIDTH;
+ }
+ return costs;
+ }
+
+ private static double getWriteCost(Collection<Long> R, CPlanMemoTable memo) {
+ double costs = 0;
+ for( Long hopID : R ) {
+ Hop hop = memo.getHopRefs().get(hopID);
+ costs += getSize(hop) * 8 / WRITE_BANDWIDTH;
+ }
+ return costs;
+ }
+
+ private static double getComputeCost(HashMap<Long, Double> computeCosts, CPlanMemoTable memo) {
+ double costs = 0;
+ for( Entry<Long,Double> e : computeCosts.entrySet() ) {
+ Hop mainInput = memo.getHopRefs()
+ .get(e.getKey()).getInput().get(0);
+ costs += getSize(mainInput) * e.getValue() / COMPUTE_BANDWIDTH;
+ }
+ return costs;
+ }
+
+ private static long getSize(Hop hop) {
+ return Math.max(hop.getDim1(),1)
+ * Math.max(hop.getDim2(),1);
+ }
+
	/**
	 * Within-partition multi-agg templates: constructs and adds multi-aggregate
	 * (MAGG) plans for independent full aggregations of the same partition,
	 * with at most three aggregations per fused operator.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param partition hop IDs of the current connected partition
	 * @param R hop IDs of the partition root nodes
	 */
	private static void createAndAddMultiAggPlans(CPlanMemoTable memo, HashSet<Long> partition, HashSet<Long> R)
	{
		//create index of plans that reference full aggregates to avoid circular dependencies
		HashSet<Long> refHops = new HashSet<Long>();
		for( Entry<Long, List<MemoTableEntry>> e : memo.getPlans().entrySet() )
			if( !e.getValue().isEmpty() ) {
				Hop hop = memo.getHopRefs().get(e.getKey());
				for( Hop c : hop.getInput() )
					refHops.add(c.getHopID());
			}

		//find all full aggregations (the fact that they are in the same partition guarantees
		//that they also have common subexpressions, also full aggregations are by def root nodes)
		ArrayList<Long> fullAggs = new ArrayList<Long>();
		for( Long hopID : R ) {
			Hop root = memo.getHopRefs().get(hopID);
			if( !refHops.contains(hopID) && isMultiAggregateRoot(root) )
				fullAggs.add(hopID);
		}
		if( LOG.isTraceEnabled() ) {
			LOG.trace("Found within-partition ua(RC) aggregations: " +
				Arrays.toString(fullAggs.toArray(new Long[0])));
		}

		//construct and add multiagg template plans (w/ max 3 aggregations)
		for( int i=0; i<fullAggs.size(); i+=3 ) {
			int ito = Math.min(i+3, fullAggs.size());
			if( ito-i >= 2 ) {
				//pack 2-3 aggregations into a single MAGG entry (unused third slot: -1)
				MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
					fullAggs.get(i), fullAggs.get(i+1), ((ito-i)==3)?fullAggs.get(i+2):-1, ito-i);
				if( isValidMultiAggregate(memo, me) ) {
					//register the shared MAGG entry under each participating aggregate
					for( int j=i; j<ito; j++ ) {
						memo.add(memo.getHopRefs().get(fullAggs.get(j)), me);
						if( LOG.isTraceEnabled() )
							LOG.trace("Added multiagg plan: "+fullAggs.get(j)+" "+me);
					}
				}
				else if( LOG.isTraceEnabled() ) {
					LOG.trace("Removed invalid multiagg plan: "+me);
				}
			}
		}
	}
+
	/**
	 * Across-partition multi-agg templates with shared reads: collects all
	 * applicable full aggregations over the entire DAG, greedily merges
	 * mergable candidates, and adds MAGG plans (max 3 aggregations each).
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param roots root hops of the current DAG
	 */
	private void createAndAddMultiAggPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
	{
		//collect full aggregations as initial set of candidates
		HashSet<Long> fullAggs = new HashSet<Long>();
		Hop.resetVisitStatus(roots);
		for( Hop hop : roots )
			rCollectFullAggregates(hop, fullAggs);
		Hop.resetVisitStatus(roots);

		//remove operators with assigned multi-agg plans
		fullAggs.removeIf(p -> memo.contains(p, TemplateType.MAGG));

		//check applicability for further analysis
		//(a single aggregation cannot form a multi-aggregate)
		if( fullAggs.size() <= 1 )
			return;

		if( LOG.isTraceEnabled() ) {
			LOG.trace("Found across-partition ua(RC) aggregations: " +
				Arrays.toString(fullAggs.toArray(new Long[0])));
		}

		//collect information for all candidates
		//(subsumed aggregations, and inputs to fused operators)
		List<AggregateInfo> aggInfos = new ArrayList<AggregateInfo>();
		for( Long hopID : fullAggs ) {
			Hop aggHop = memo.getHopRefs().get(hopID);
			AggregateInfo tmp = new AggregateInfo(aggHop);
			for( int i=0; i<aggHop.getInput().size(); i++ ) {
				//for matrix multiplies, descend below the transpose on input 0
				//(isMultiAggregateRoot guarantees a transpose there)
				Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ?
					aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
				rExtractAggregateInfo(memo, c, tmp, TemplateType.CELL);
			}
			//fallback: if no fused inputs were found, use the direct operator inputs
			if( tmp._fusedInputs.isEmpty() ) {
				if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
					tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
					tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
				}
				else
					tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
			}
			aggInfos.add(tmp);
		}

		if( LOG.isTraceEnabled() ) {
			LOG.trace("Extracted across-partition ua(RC) aggregation info: ");
			for( AggregateInfo info : aggInfos )
				LOG.trace(info);
		}

		//sort aggregations by num dependencies to simplify merging
		//clusters of aggregations with parallel dependencies
		aggInfos = aggInfos.stream()
			.sorted(Comparator.comparing(a -> a._inputAggs.size()))
			.collect(Collectors.toList());

		//greedy grouping of multi-agg candidates
		boolean converged = false;
		while( !converged ) {
			AggregateInfo merged = null;
			for( int i=0; i<aggInfos.size(); i++ ) {
				AggregateInfo current = aggInfos.get(i);
				for( int j=i+1; j<aggInfos.size(); j++ ) {
					AggregateInfo that = aggInfos.get(j);
					if( current.isMergable(that) ) {
						merged = current.merge(that);
						aggInfos.remove(j); j--;
					}
				}
			}
			//fixpoint: stop once a full pass performed no merge
			converged = (merged == null);
		}

		if( LOG.isTraceEnabled() ) {
			LOG.trace("Merged across-partition ua(RC) aggregation info: ");
			for( AggregateInfo info : aggInfos )
				LOG.trace(info);
		}

		//construct and add multiagg template plans (w/ max 3 aggregations)
		for( AggregateInfo info : aggInfos ) {
			if( info._aggregates.size()<=1 )
				continue;
			Long[] aggs = info._aggregates.keySet().toArray(new Long[0]);
			MemoTableEntry me = new MemoTableEntry(TemplateType.MAGG,
				aggs[0], aggs[1], (aggs.length>2)?aggs[2]:-1, aggs.length);
			for( int i=0; i<aggs.length; i++ ) {
				memo.add(memo.getHopRefs().get(aggs[i]), me);
				addBestPlan(aggs[i], me);
				if( LOG.isTraceEnabled() )
					LOG.trace("Added multiagg* plan: "+aggs[i]+" "+me);

			}
		}
	}
+
+ private static boolean isMultiAggregateRoot(Hop root) {
+ return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
+ && ((AggUnaryOp)root).getDirection()==Direction.RowCol)
+ || (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1
+ && HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
+ }
+
	/**
	 * Validates a candidate multi-aggregate entry: all fused aggregates must
	 * (1) have inputs of identical size and (2) be independent of each other,
	 * i.e., no (transitive) parent-child references between aggregate roots.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param me candidate MAGG memo table entry
	 * @return true if the multi-aggregate entry is valid
	 */
	private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
		//ensure input consistent sizes (otherwise potential for incorrect results)
		boolean ret = true;
		Hop refSize = memo.getHopRefs().get(me.input1).getInput().get(0);
		for( int i=1; ret && i<3; i++ ) {
			if( me.isPlanRef(i) )
				ret &= HopRewriteUtils.isEqualSize(refSize,
					memo.getHopRefs().get(me.input(i)).getInput().get(0));
		}

		//ensure that aggregates are independent of each other, i.e.,
		//they do not have potentially transitive parent child references
		for( int i=0; ret && i<3; i++ )
			if( me.isPlanRef(i) ) {
				//probe the i-th aggregate's subtree for references to its siblings
				HashSet<Long> probe = new HashSet<Long>();
				for( int j=0; j<3; j++ )
					if( i != j )
						probe.add(me.input(j));
				ret &= rCheckMultiAggregate(memo.getHopRefs().get(me.input(i)), probe);
			}
		return ret;
	}
+
+ private static boolean rCheckMultiAggregate(Hop current, HashSet<Long> probe) {
+ boolean ret = true;
+ for( Hop c : current.getInput() )
+ ret &= rCheckMultiAggregate(c, probe);
+ ret &= !probe.contains(current.getHopID());
+ return ret;
+ }
+
+ private static void rCollectFullAggregates(Hop current, HashSet<Long> aggs) {
+ if( current.isVisited() )
+ return;
+
+ //collect all applicable full aggregations per read
+ if( isMultiAggregateRoot(current) )
+ aggs.add(current.getHopID());
+
+ //recursively process children
+ for( Hop c : current.getInput() )
+ rCollectFullAggregates(c, aggs);
+
+ current.setVisited();
+ }
+
	/**
	 * Recursively collects, for a given aggregation candidate, its input
	 * aggregates (dependencies) and the inputs of the would-be fused operator.
	 * A non-null type indicates the traversal is still inside the fused
	 * operator; null indicates it has left the fused operator.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current currently processed hop
	 * @param aggInfo aggregate info collector (modified in place)
	 * @param type current template type, or null if outside the fused operator
	 */
	private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) {
		//collect input aggregates (dependents)
		if( isMultiAggregateRoot(current) )
			aggInfo.addInputAggregate(current.getHopID());

		//recursively process children
		MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null;
		for( int i=0; i<current.getInput().size(); i++ ) {
			Hop c = current.getInput().get(i);
			if( me != null && me.isPlanRef(i) )
				//input is fused into the operator: stay inside (keep type)
				rExtractAggregateInfo(memo, c, aggInfo, type);
			else {
				if( type != null && c.getDataType().isMatrix() ) //add fused input
					aggInfo.addFusedInput(c.getHopID());
				//leave the fused operator (type=null) for the remaining subtree
				rExtractAggregateInfo(memo, c, aggInfo, null);
			}
		}
	}
+
	/**
	 * Determines whether the best ROW plan of the given hop contains no
	 * aggregation below the root (the root aggregation itself is permitted,
	 * hence only the referenced inputs are checked recursively).
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current root hop of the row template (assumes a ROW entry exists)
	 * @param visited hop IDs already processed (memoization)
	 * @return true if the row template contains no aggregation below the root
	 */
	private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
		//consider all aggregations other than root operation
		MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
		boolean ret = true;
		for(int i=0; i<3; i++)
			if( me.isPlanRef(i) )
				ret &= rIsRowTemplateWithoutAgg(memo,
					current.getInput().get(i), visited);
		return ret;
	}
+
+ private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+ if( visited.contains(current.getHopID()) )
+ return true;
+
+ boolean ret = true;
+ MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
+ for(int i=0; i<3; i++)
+ if( me.isPlanRef(i) )
+ ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
+ ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp);
+
+ visited.add(current.getHopID());
+ return ret;
+ }
+
	/**
	 * Recursively removes memo table entries that are inconsistent with the
	 * selected assignment of materialization points (as determined by
	 * hasNoRefToMatPoint); outer-product entries are exempted.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current currently processed hop
	 * @param visited hop IDs already processed (memoization)
	 * @param part current plan partition
	 * @param matPoints sorted materialization points of the partition
	 * @param plan selected boolean assignment of materialization points
	 */
	private static void rPruneSuboptimalPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited,
		PlanPartition part, InterestingPoint[] matPoints, boolean[] plan)
	{
		//memoization (not via hops because in middle of dag)
		if( visited.contains(current.getHopID()) )
			return;

		//remove memo table entries if necessary
		long hopID = current.getHopID();
		if( part.getPartition().contains(hopID) && memo.contains(hopID) ) {
			Iterator<MemoTableEntry> iter = memo.get(hopID).iterator();
			while( iter.hasNext() ) {
				MemoTableEntry me = iter.next();
				//drop entries that conflict with the chosen materialization points;
				//OUTER entries are kept (they were pruned separately beforehand)
				if( !hasNoRefToMatPoint(hopID, me, matPoints, plan) && me.type!=TemplateType.OUTER ) {
					iter.remove();
					if( LOG.isTraceEnabled() )
						LOG.trace("Removed memo table entry: "+me);
				}
			}
		}

		//process children recursively
		for( Hop c : current.getInput() )
			rPruneSuboptimalPlans(memo, c, visited, part, matPoints, plan);

		visited.add(current.getHopID());
	}
+
	/**
	 * Recursively converts ROW memo table entries that became invalid after the
	 * previous pruning step into CELL entries: leaf nodes without matrix inputs
	 * and inner nodes without any remaining row-template input (processed
	 * bottom-up, hence children are handled before the current node).
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current currently processed hop
	 * @param visited hop IDs already processed (memoization)
	 * @param part current plan partition
	 * @param plan selected boolean assignment of materialization points
	 */
	private static void rPruneInvalidPlans(CPlanMemoTable memo, Hop current, HashSet<Long> visited, PlanPartition part, boolean[] plan) {
		//memoization (not via hops because in middle of dag)
		if( visited.contains(current.getHopID()) )
			return;

		//process children recursively
		for( Hop c : current.getInput() )
			rPruneInvalidPlans(memo, c, visited, part, plan);

		//find invalid row aggregate leaf nodes (see TemplateRow.open) w/o matrix inputs,
		//i.e., plans that become invalid after the previous pruning step
		long hopID = current.getHopID();
		if( part.getPartition().contains(hopID) && memo.contains(hopID, TemplateType.ROW) ) {
			for( MemoTableEntry me : memo.get(hopID) ) {
				if( me.type==TemplateType.ROW ) {
					//convert leaf node with pure vector inputs
					if( !me.hasPlanRef() && !TemplateUtils.hasMatrixInput(current) ) {
						me.type = TemplateType.CELL;
						if( LOG.isTraceEnabled() )
							LOG.trace("Converted leaf memo table entry from row to cell: "+me);
					}

					//convert inner node without row template input
					if( me.hasPlanRef() && !ROW_TPL.open(current) ) {
						boolean hasRowInput = false;
						for( int i=0; i<3; i++ )
							if( me.isPlanRef(i) )
								hasRowInput |= memo.contains(me.input(i), TemplateType.ROW);
						if( !hasRowInput ) {
							me.type = TemplateType.CELL;
							if( LOG.isTraceEnabled() )
								LOG.trace("Converted inner memo table entry from row to cell: "+me);
						}
					}

				}
			}
		}

		visited.add(current.getHopID());
	}
+
	/**
	 * Recursively selects the best remaining memo table entry per hop along
	 * fused operator paths (fuse-all heuristic within a partition): prunes
	 * subsumed plans of the same type, picks the best entry for the current
	 * path, and propagates the preferred template type to fused inputs.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current currently processed hop
	 * @param currentType preferred template type from the parent, or null at roots
	 * @param partition hop IDs of the current connected partition
	 */
	private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition)
	{
		//skip hops already processed under this type, or outside the partition
		if( isVisited(current.getHopID(), currentType)
			|| !partition.contains(current.getHopID()) )
			return;

		//step 1: prune subsumed plans of same type
		if( memo.contains(current.getHopID()) ) {
			HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
			List<MemoTableEntry> hopP = memo.get(current.getHopID());
			for( MemoTableEntry e1 : hopP )
				for( MemoTableEntry e2 : hopP )
					if( e1 != e2 && e1.subsumes(e2) )
						rmSet.add(e2);
			memo.remove(current, rmSet);
		}

		//step 2: select plan for current path
		MemoTableEntry best = null;
		if( memo.contains(current.getHopID()) ) {
			if( currentType == null ) {
				//no preference: take the overall best valid entry
				best = memo.get(current.getHopID()).stream()
					.filter(p -> isValid(p, current))
					.min(BASE_COMPARE).orElse(null);
			}
			else {
				//prefer the parent's type; CELL entries are always admissible
				_typedCompare.setType(currentType);
				best = memo.get(current.getHopID()).stream()
					.filter(p -> p.type==currentType || p.type==TemplateType.CELL)
					.min(_typedCompare).orElse(null);
			}
			addBestPlan(current.getHopID(), best);
		}

		//step 3: recursively process children
		for( int i=0; i< current.getInput().size(); i++ ) {
			//propagate the chosen type only along fused (plan-referenced) inputs
			TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
			rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
		}

		setVisited(current.getHopID(), currentType);
	}
+
+ /////////////////////////////////////////////////////////
+ // Cost model fused operators w/ materialization points
+ //////////
+
	/**
	 * Computes the total cost of a partition under the given assignment of
	 * materialization points by costing all partition roots recursively.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param part current plan partition
	 * @param matPoints sorted materialization points of the partition
	 * @param plan boolean assignment of materialization points
	 * @param computeCosts per-cell compute costs per partition hop
	 * @return total estimated costs of the partition under this plan
	 */
	private static double getPlanCost(CPlanMemoTable memo, PlanPartition part,
		InterestingPoint[] matPoints,boolean[] plan, HashMap<Long, Double> computeCosts)
	{
		//high level heuristic: every hop or fused operator has the following cost:
		//WRITE + max(COMPUTE, READ), where WRITE costs are given by the output size,
		//READ costs by the input sizes, and COMPUTE by operation specific FLOP
		//counts times number of cells of main input, disregarding sparsity for now.

		HashSet<VisitMarkCost> visited = new HashSet<>();
		double costs = 0;
		for( Long hopID : part.getRoots() ) {
			costs += rGetPlanCosts(memo, memo.getHopRefs().get(hopID),
				visited, part, matPoints, plan, computeCosts, null, null);
		}
		return costs;
	}
+
	/**
	 * Recursively costs the subtree of the given hop under the current plan
	 * assignment. Each opened fused operator accumulates a cost vector
	 * (output size, input sizes, compute costs) which is converted into
	 * WRITE + max(READ, COMPUTE) time when the operator is closed.
	 *
	 * @param memo memoization table of partial fusion plans
	 * @param current currently processed hop
	 * @param visited memoized (hop, cost-vector) pairs to avoid double counting
	 * @param part current plan partition
	 * @param matPoints sorted materialization points of the partition
	 * @param plan boolean assignment of materialization points
	 * @param computeCosts per-cell compute costs per partition hop
	 * @param costsCurrent cost vector of the enclosing fused operator (null if none)
	 * @param currentType template type of the enclosing fused operator (null if none)
	 * @return estimated costs of the subtree rooted at the current hop
	 */
	private static double rGetPlanCosts(CPlanMemoTable memo, Hop current, HashSet<VisitMarkCost> visited,
		PlanPartition part, InterestingPoint[] matPoints, boolean[] plan, HashMap<Long, Double> computeCosts,
		CostVector costsCurrent, TemplateType currentType)
	{
		//memoization per hop id and cost vector to account for redundant
		//computation without double counting materialized results or compute
		//costs of complex operation DAGs within a single fused operator
		VisitMarkCost tag = new VisitMarkCost(current.getHopID(),
			(costsCurrent==null || currentType==TemplateType.MAGG)?0:costsCurrent.ID);
		if( visited.contains(tag) )
			return 0;
		visited.add(tag);

		//open template if necessary, including memoization
		//under awareness of current plan choice
		MemoTableEntry best = null;
		boolean opened = false;
		if( memo.contains(current.getHopID()) ) {
			//note: this is the inner loop of plan enumeration and hence, we do not
			//use streams, lambda expressions, etc to avoid unnecessary overhead
			long hopID = current.getHopID();
			if( currentType == null ) {
				//no enclosing template: choose overall best valid entry and open it
				for( MemoTableEntry me : memo.get(hopID) )
					best = isValid(me, current)
						&& hasNoRefToMatPoint(hopID, me, matPoints, plan)
						&& BasicPlanComparator.icompare(me, best)<0 ? me : best;
				opened = true;
			}
			else {
				//inside a template: only entries of the same type (or CELL) continue the fusion
				for( MemoTableEntry me : memo.get(hopID) )
					best = (me.type == currentType || me.type==TemplateType.CELL)
						&& hasNoRefToMatPoint(hopID, me, matPoints, plan)
						&& TypedPlanComparator.icompare(me, best, currentType)<0 ? me : best;
			}
		}

		//create new cost vector if opened, initialized with write costs
		CostVector costVect = !opened ? costsCurrent : new CostVector(getSize(current));
		double costs = 0;

		//add other roots for multi-agg template to account for shared costs
		if( opened && best != null && best.type == TemplateType.MAGG ) {
			//account costs to first multi-agg root
			if( best.input1 == current.getHopID() )
				for( int i=1; i<3; i++ ) {
					if( !best.isPlanRef(i) ) continue;
					costs += rGetPlanCosts(memo, memo.getHopRefs().get(best.input(i)), visited,
						part, matPoints, plan, computeCosts, costVect, TemplateType.MAGG);
				}
			//skip other multi-agg roots
			else
				return 0;
		}

		//add compute costs of current operator to costs vector
		//NOTE(review): assumes every partition hop has an entry in computeCosts
		//(populated by rGetComputeCosts) and costVect is non-null here — verify
		if( part.getPartition().contains(current.getHopID()) )
			costVect.computeCosts += computeCosts.get(current.getHopID());

		//process children recursively
		for( int i=0; i< current.getInput().size(); i++ ) {
			Hop c = current.getInput().get(i);
			if( best!=null && best.isPlanRef(i) )
				//fused input: continue accumulating into the same cost vector
				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, costVect, best.type);
			else if( best!=null && isImplicitlyFused(current, i, best.type) )
				costVect.addInputSize(c.getInput().get(0).getHopID(), getSize(c));
			else { //include children and I/O costs
				costs += rGetPlanCosts(memo, c, visited, part, matPoints, plan, computeCosts, null, null);
				if( costVect != null && c.getDataType().isMatrix() )
					costVect.addInputSize(c.getHopID(), getSize(c));
			}
		}

		//add costs for opened fused operator
		if( part.getPartition().contains(current.getHopID()) ) {
			if( opened ) {
				if( LOG.isTraceEnabled() ) {
					String type = (best !=null) ? best.type.name() : "HOP";
					LOG.trace("Cost vector ("+type+" "+current.getHopID()+"): "+costVect);
				}
				double tmpCosts = costVect.outSize * 8 / WRITE_BANDWIDTH //time for output write
					+ Math.max(costVect.getSumInputSizes() * 8 / READ_BANDWIDTH,
					costVect.computeCosts*costVect.getMaxInputSize()/ COMPUTE_BANDWIDTH);
				//sparsity correction for outer-product template (and sparse-safe cell)
				if( best != null && best.type == TemplateType.OUTER ) {
					Hop driver = memo.getHopRefs().get(costVect.getMaxInputSizeHopID());
					tmpCosts *= driver.dimsKnown(true) ? driver.getSparsity() : SPARSE_SAFE_SPARSITY_EST;
				}
				costs += tmpCosts;
			}
			//add costs for non-partition read in the middle of fused operator
			else if( part.getExtConsumed().contains(current.getHopID()) ) {
				costs += rGetPlanCosts(memo, current, visited,
					part, matPoints, plan, computeCosts, null, null);
			}
		}

		//sanity check non-negative costs
		if( costs < 0 || Double.isNaN(costs) || Double.isInfinite(costs) )
			throw new RuntimeException("Wrong cost estimate: "+costs);

		return costs;
	}
+
+ /**
+ * Recursively collects per-operator compute cost estimates for all hops of
+ * the given partition, bottom-up, memoized in {@code computeCosts}.
+ * Costs are relative units per output value (a trivial copy/compare = 1);
+ * the constants below are hand-tuned estimates per operation type.
+ *
+ * @param current hop to process (children are processed first)
+ * @param partition hop IDs belonging to the current plan partition
+ * @param computeCosts output map: hop ID -> estimated compute cost
+ */
+ private static void rGetComputeCosts(Hop current, HashSet<Long> partition, HashMap<Long, Double> computeCosts)
+ {
+ //memoization and partition-scope check (avoid redundant work / out-of-partition hops)
+ if( computeCosts.containsKey(current.getHopID())
+ || !partition.contains(current.getHopID()) )
+ return;
+
+ //recursively process children
+ for( Hop c : current.getInput() )
+ rGetComputeCosts(c, partition, computeCosts);
+
+ //get costs for given hop
+ double costs = 1;
+ if( current instanceof UnaryOp ) {
+ switch( ((UnaryOp)current).getOp() ) {
+ case ABS:
+ case ROUND:
+ case CEIL:
+ case FLOOR:
+ case SIGN:
+ case SELP: costs = 1; break;
+ case SPROP:
+ case SQRT: costs = 2; break;
+ case EXP: costs = 18; break;
+ case SIGMOID: costs = 21; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ case NCOL:
+ case NROW:
+ case PRINT:
+ case CAST_AS_BOOLEAN:
+ case CAST_AS_DOUBLE:
+ case CAST_AS_INT:
+ case CAST_AS_MATRIX:
+ case CAST_AS_SCALAR: costs = 1; break;
+ case SIN: costs = 18; break;
+ case COS: costs = 22; break;
+ case TAN: costs = 42; break;
+ case ASIN: costs = 93; break;
+ case ACOS: costs = 103; break;
+ case ATAN: costs = 40; break;
+ case CUMSUM:
+ case CUMMIN:
+ case CUMMAX:
+ case CUMPROD: costs = 1; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((UnaryOp)current).getOp());
+ }
+ }
+ else if( current instanceof BinaryOp ) {
+ switch( ((BinaryOp)current).getOp() ) {
+ case MULT:
+ case PLUS:
+ case MINUS:
+ case MIN:
+ case MAX:
+ case AND:
+ case OR:
+ case EQUAL:
+ case NOTEQUAL:
+ case LESS:
+ case LESSEQUAL:
+ case GREATER:
+ case GREATEREQUAL:
+ case CBIND:
+ case RBIND: costs = 1; break;
+ case INTDIV: costs = 6; break;
+ case MODULUS: costs = 8; break;
+ case DIV: costs = 22; break;
+ case LOG:
+ case LOG_NZ: costs = 32; break;
+ //squaring (pow 2) is a single multiply, general pow is far more expensive
+ case POW: costs = (HopRewriteUtils.isLiteralOfValue(
+ current.getInput().get(1), 2) ? 1 : 16); break;
+ case MINUS_NZ:
+ case MINUS1_MULT: costs = 2; break;
+ case CENTRALMOMENT:
+ //order of the central moment (defaults to 2, i.e., cm2, if non-literal)
+ int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
+ HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 1; break; //count
+ case 1: costs = 8; break; //mean
+ case 2: costs = 16; break; //cm2
+ case 3: costs = 31; break; //cm3
+ case 4: costs = 51; break; //cm4
+ case 5: costs = 16; break; //variance
+ }
+ break;
+ case COVARIANCE: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((BinaryOp)current).getOp());
+ }
+ }
+ else if( current instanceof TernaryOp ) {
+ switch( ((TernaryOp)current).getOp() ) {
+ case PLUS_MULT:
+ case MINUS_MULT: costs = 2; break;
+ case CTABLE: costs = 3; break;
+ case CENTRALMOMENT:
+ //order of the central moment (defaults to 2, i.e., cm2, if non-literal)
+ int type = (int) (current.getInput().get(1) instanceof LiteralOp ?
+ HopRewriteUtils.getIntValueSafe((LiteralOp)current.getInput().get(1)) : 2);
+ switch( type ) {
+ case 0: costs = 2; break; //count
+ case 1: costs = 9; break; //mean
+ case 2: costs = 17; break; //cm2
+ case 3: costs = 32; break; //cm3
+ case 4: costs = 52; break; //cm4
+ case 5: costs = 17; break; //variance
+ }
+ break;
+ case COVARIANCE: costs = 23; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((TernaryOp)current).getOp());
+ }
+ }
+ else if( current instanceof ParameterizedBuiltinOp ) {
+ costs = 1;
+ }
+ else if( current instanceof IndexingOp ) {
+ costs = 1;
+ }
+ else if( current instanceof ReorgOp ) {
+ costs = 1;
+ }
+ else if( current instanceof AggBinaryOp ) {
+ //matrix multiply: 2 flops (multiply-add) scaled by the relevant dimension
+ //outer product template
+ if( HopRewriteUtils.isOuterProductLikeMM(current) )
+ costs = 2 * current.getInput().get(0).getDim2();
+ //row template w/ matrix-vector or matrix-matrix
+ else
+ costs = 2 * current .getDim2();
+ }
+ else if( current instanceof AggUnaryOp) {
+ switch(((AggUnaryOp)current).getOp()) {
+ case SUM: costs = 4; break;
+ case SUM_SQ: costs = 5; break;
+ case MIN:
+ case MAX: costs = 1; break;
+ default:
+ LOG.warn("Cost model not "
+ + "implemented yet for: "+((AggUnaryOp)current).getOp());
+ }
+ }
+
+ computeCosts.put(current.getHopID(), costs);
+ }
+
+ //returns true iff the given memo entry does not reference any (selected)
+ //materialization point of hopID w.r.t. the current plan assignment
+ private static boolean hasNoRefToMatPoint(long hopID,
+ MemoTableEntry me, InterestingPoint[] M, boolean[] plan) {
+ return !InterestingPoint.isMatPoint(M, hopID, me, plan);
+ }
+
+ //returns true iff the i-th input is implicitly handled by the fused operator:
+ //only for the ROW template, a matrix multiply, and a transposed first input
+ //(presumably t(X) %*% ... is read directly without materializing t(X) — confirm)
+ private static boolean isImplicitlyFused(Hop hop, int index, TemplateType type) {
+ return type == TemplateType.ROW
+ && HopRewriteUtils.isMatrixMultiply(hop) && index==0
+ && HopRewriteUtils.isTransposeOperation(hop.getInput().get(index));
+ }
+
+ /**
+ * Cost vector of a single (tentative) fused operator: output size,
+ * accumulated compute costs, and per-input sizes keyed by hop ID
+ * (the map ensures shared inputs are counted only once).
+ */
+ private static class CostVector {
+ public final long ID; //unique cost vector ID
+ public final double outSize; //size of the fused operator output
+ public double computeCosts = 0; //accumulated compute costs of fused hops
+ public final HashMap<Long, Double> inSizes = new HashMap<Long, Double>(); //input hop ID -> size
+
+ public CostVector(double outputSize) {
+ ID = COST_ID.getNextID();
+ outSize = outputSize;
+ }
+ public void addInputSize(long hopID, double inputSize) {
+ //ensures that input sizes are not double counted
+ inSizes.put(hopID, inputSize);
+ }
+ public double getSumInputSizes() {
+ return inSizes.values().stream()
+ .mapToDouble(d -> d.doubleValue()).sum();
+ }
+ //maximum input size, 0 if there are no inputs
+ public double getMaxInputSize() {
+ return inSizes.values().stream()
+ .mapToDouble(d -> d.doubleValue()).max().orElse(0);
+ }
+ //hop ID of the largest input, -1 if there are no inputs
+ public long getMaxInputSizeHopID() {
+ long id = -1; double max = 0;
+ for( Entry<Long,Double> e : inSizes.entrySet() )
+ if( max < e.getValue() ) {
+ id = e.getKey();
+ max = e.getValue();
+ }
+ return id;
+ }
+ @Override
+ public String toString() {
+ return "["+outSize+", "+computeCosts+", {"
+ +Arrays.toString(inSizes.keySet().toArray(new Long[0]))+", "
+ +Arrays.toString(inSizes.values().toArray(new Double[0]))+"}]";
+ }
+ }
+
+ /**
+ * Immutable holder of plan-independent ("static") costs: the per-hop
+ * compute cost map and aggregate compute/read/write cost components.
+ */
+ private static class StaticCosts {
+ public final HashMap<Long, Double> _computeCosts; //hop ID -> compute cost
+ public final double _compute; //aggregate compute costs
+ public final double _read; //aggregate read costs
+ public final double _write; //aggregate write costs
+
+ public StaticCosts(HashMap<Long,Double> allComputeCosts, double computeCost, double readCost, double writeCost) {
+ _computeCosts = allComputeCosts;
+ _compute = computeCost;
+ _read = readCost;
+ _write = writeCost;
+ }
+ }
+
+ /**
+ * Grouping info for candidate multi-aggregate fusion: the aggregates of
+ * the group, aggregate hop IDs consumed as inputs (for independence
+ * checks), and fused input hop IDs (for shared-read checks). A group
+ * holds at most 3 aggregates (see isMergable).
+ */
+ private static class AggregateInfo {
+ public final HashMap<Long,Hop> _aggregates; //aggregate hop ID -> hop
+ public final HashSet<Long> _inputAggs = new HashSet<Long>(); //aggregates consumed as inputs
+ public final HashSet<Long> _fusedInputs = new HashSet<Long>(); //inputs of fused operators
+ public AggregateInfo(Hop aggregate) {
+ _aggregates = new HashMap<Long, Hop>();
+ _aggregates.put(aggregate.getHopID(), aggregate);
+ }
+ public void addInputAggregate(long hopID) {
+ _inputAggs.add(hopID);
+ }
+ public void addFusedInput(long hopID) {
+ _fusedInputs.add(hopID);
+ }
+ //true iff merging with 'that' yields a valid multi-agg group:
+ //at most 3 aggregates, mutually independent (no aggregate of one group
+ //consumed by the other), at least one shared fused input, and inputs
+ //of consistent size (required for result correctness)
+ public boolean isMergable(AggregateInfo that) {
+ //check independence
+ boolean ret = _aggregates.size()<3
+ && _aggregates.size()+that._aggregates.size()<=3;
+ for( Long hopID : that._aggregates.keySet() )
+ ret &= !_inputAggs.contains(hopID);
+ for( Long hopID : _aggregates.keySet() )
+ ret &= !that._inputAggs.contains(hopID);
+ //check partial shared reads
+ ret &= !CollectionUtils.intersection(
+ _fusedInputs, that._fusedInputs).isEmpty();
+ //check consistent sizes (result correctness)
+ Hop in1 = _aggregates.values().iterator().next();
+ Hop in2 = that._aggregates.values().iterator().next();
+ return ret && HopRewriteUtils.isEqualSize(
+ in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
+ in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
+ }
+ //merges 'that' into this group (mutates and returns this instance)
+ public AggregateInfo merge(AggregateInfo that) {
+ _aggregates.putAll(that._aggregates);
+ _inputAggs.addAll(that._inputAggs);
+ _fusedInputs.addAll(that._fusedInputs);
+ return this;
+ }
+ @Override
+ public String toString() {
+ return "["+Arrays.toString(_aggregates.keySet().toArray(new Long[0]))+": "
+ +"{"+Arrays.toString(_inputAggs.toArray(new Long[0]))+"},"
+ +"{"+Arrays.toString(_fusedInputs.toArray(new Long[0]))+"}]";
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
new file mode 100644
index 0000000..759a903
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseNoRedundancy.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Map.Entry;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
+import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
+
+/**
+ * This plan selection heuristic aims for fusion without any redundant
+ * computation, which, however, potentially leads to more materialized
+ * intermediates than the fuse all heuristic.
+ * <p>
+ * NOTE: This heuristic is essentially the same as FuseAll, except that
+ * any plans that refer to a hop with multiple consumers are removed in
+ * a pre-processing step.
+ *
+ */
+public class PlanSelectionFuseNoRedundancy extends PlanSelection
+{
+ @Override
+ public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
+ //pruning and collection pass
+ for( Hop hop : roots )
+ rSelectPlans(memo, hop, null);
+
+ //take all distinct best plans
+ for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
+ memo.setDistinct(e.getKey(), e.getValue());
+ }
+
+ /**
+ * Recursive plan selection for a single hop: removes plan references to
+ * hops with multiple consumers (no-redundancy property), prunes subsumed
+ * plans, selects the best remaining plan, and recurses into children with
+ * the selected template type as preference.
+ *
+ * @param memo memo table of partial fusion plans
+ * @param current hop to process
+ * @param currentType template type of the enclosing fused operator, null if none
+ */
+ private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType)
+ {
+ //avoid redundant processing per (hop, type) combination
+ if( isVisited(current.getHopID(), currentType) )
+ return;
+
+ //step 0: remove plans that refer to a common partial plan
+ if( memo.contains(current.getHopID()) ) {
+ HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+ List<MemoTableEntry> hopP = memo.get(current.getHopID());
+ for( MemoTableEntry e1 : hopP )
+ for( int i=0; i<3; i++ )
+ if( e1.isPlanRef(i) && current.getInput().get(i).getParent().size()>1 )
+ rmSet.add(e1); //remove references to hops w/ multiple consumers
+ memo.remove(current, rmSet);
+ }
+
+ //step 1: prune subsumed plans of same type
+ if( memo.contains(current.getHopID()) ) {
+ HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+ List<MemoTableEntry> hopP = memo.get(current.getHopID());
+ for( MemoTableEntry e1 : hopP )
+ for( MemoTableEntry e2 : hopP )
+ if( e1 != e2 && e1.subsumes(e2) )
+ rmSet.add(e2);
+ memo.remove(current, rmSet);
+ }
+
+ //step 2: select plan for current path
+ MemoTableEntry best = null;
+ if( memo.contains(current.getHopID()) ) {
+ if( currentType == null ) {
+ best = memo.get(current.getHopID()).stream()
+ .filter(p -> isValid(p, current))
+ .min(new BasicPlanComparator()).orElse(null);
+ }
+ else {
+ //score: prefer entries of the current template type (bonus 4)
+ //and entries with more plan references; smaller score wins
+ best = memo.get(current.getHopID()).stream()
+ .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+ .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
+ .orElse(null);
+ }
+ addBestPlan(current.getHopID(), best);
+ }
+
+ //step 3: recursively process children
+ for( int i=0; i< current.getInput().size(); i++ ) {
+ TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
+ rSelectPlans(memo, current.getInput().get(i), pref);
+ }
+
+ setVisited(current.getHopID(), currentType);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
new file mode 100644
index 0000000..de1ed92
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/ReachabilityGraph.java
@@ -0,0 +1,398 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.codegen.opt;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
+import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysml.hops.codegen.opt.PlanSelection.VisitMarkCost;
+
+/**
+ * Reachability graph over the materialization points of a plan partition.
+ * It derives candidate cutsets that decompose the exponential enumeration
+ * space into independent subproblems, and linearizes the search space so
+ * that cutset points come first (enabling skipping of entire subspaces).
+ */
+public class ReachabilityGraph
+{
+ //repository of materialization points: (fromHopID, toHopID) -> graph node
+ private HashMap<Pair<Long,Long>,NodeLink> _matPoints = null;
+ //artificial root node (no interesting point) above all partition roots
+ private NodeLink _root = null;
+
+ private InterestingPoint[] _searchSpace; //linearized, cutset-first search space
+ private CutSet[] _cutSets; //selected cutsets, sorted by score
+
+ public ReachabilityGraph(PlanPartition part, CPlanMemoTable memo) {
+ //create repository of materialization points
+ _matPoints = new HashMap<>();
+ for( InterestingPoint p : part.getMatPointsExt() )
+ _matPoints.put(Pair.of(p._fromHopID, p._toHopID), new NodeLink(p));
+
+ //create reachability graph
+ _root = new NodeLink(null);
+ HashSet<VisitMarkCost> visited = new HashSet<>();
+ for( Long hopID : part.getRoots() ) {
+ Hop rootHop = memo.getHopRefs().get(hopID);
+ addInputNodeLinks(rootHop, _root, part, memo, visited);
+ }
+
+ //create candidate cutsets
+ List<NodeLink> tmpCS = _matPoints.values().stream()
+ .filter(p -> p._inputs.size() > 0 && p._p != null)
+ .sorted().collect(Collectors.toList());
+
+ //short-cut for partitions without cutsets
+ if( tmpCS.isEmpty() ) {
+ _cutSets = new CutSet[0];
+ _searchSpace = part.getMatPointsExt();
+ return;
+ }
+
+ //create composite cutsets
+ //(adjacent nodes with structurally equal inputs are grouped together;
+ //relies on the sorted order and NodeLink.equals over input IDs)
+ ArrayList<ArrayList<NodeLink>> candCS = new ArrayList<>();
+ ArrayList<NodeLink> current = new ArrayList<>();
+ for( NodeLink node : tmpCS ) {
+ if( current.isEmpty() )
+ current.add(node);
+ else if( current.get(0).equals(node) )
+ current.add(node);
+ else {
+ candCS.add(current);
+ current = new ArrayList<>();
+ current.add(node);
+ }
+ }
+ if( !current.isEmpty() )
+ candCS.add(current);
+
+ //evaluate cutsets (single, and duplicate pairs)
+ ArrayList<ArrayList<NodeLink>> remain = new ArrayList<>();
+ ArrayList<Pair<CutSet,Double>> cutSets = evaluateCutSets(candCS, remain);
+ if( !remain.isEmpty() && remain.size() < 5 ) {
+ //second chance: for pairs for remaining candidates
+ ArrayList<ArrayList<NodeLink>> candCS2 = new ArrayList<>();
+ for( int i=0; i<remain.size()-1; i++)
+ for( int j=i+1; j<remain.size(); j++) {
+ ArrayList<NodeLink> tmp = new ArrayList<>();
+ tmp.addAll(remain.get(i));
+ tmp.addAll(remain.get(j));
+ candCS2.add(tmp);
+ }
+ ArrayList<Pair<CutSet,Double>> cutSets2 = evaluateCutSets(candCS2, remain);
+ //ensure constructed cutsets are disjoint
+ HashSet<InterestingPoint> testDisjoint = new HashSet<>();
+ for( Pair<CutSet,Double> cs : cutSets2 ) {
+ if( !CollectionUtils.containsAny(testDisjoint, Arrays.asList(cs.getLeft().cut)) ) {
+ cutSets.add(cs);
+ CollectionUtils.addAll(testDisjoint, cs.getLeft().cut);
+ }
+ }
+ }
+
+ //sort and linearize search space according to scores
+ _cutSets = cutSets.stream()
+ .sorted(Comparator.comparing(p -> p.getRight()))
+ .map(p -> p.getLeft()).toArray(CutSet[]::new);
+
+ HashMap<InterestingPoint, Integer> probe = new HashMap<>();
+ ArrayList<InterestingPoint> lsearchSpace = new ArrayList<>();
+ for( CutSet cs : _cutSets ) {
+ cs.updatePos(lsearchSpace.size());
+ cs.updatePartitions(probe);
+ CollectionUtils.addAll(lsearchSpace, cs.cut);
+ //NOTE(review): probe.size()-1 is evaluated before put, so the first
+ //point maps to -1 and indexes lag lsearchSpace positions by one;
+ //verify against updatePartitionIndexes usage (looks like an off-by-one)
+ for( InterestingPoint p: cs.cut )
+ probe.put(p, probe.size()-1);
+ }
+ for( InterestingPoint p : part.getMatPointsExt() )
+ if( !probe.containsKey(p) ) {
+ lsearchSpace.add(p);
+ probe.put(p, probe.size()-1);
+ }
+ _searchSpace = lsearchSpace.toArray(new InterestingPoint[0]);
+
+ //materialize partition indices
+ for( CutSet cs : _cutSets ) {
+ cs.updatePartitionIndexes(probe);
+ cs.finalizePartition();
+ }
+
+ //final sanity check of interesting points
+ if( _searchSpace.length != part.getMatPointsExt().length )
+ throw new RuntimeException("Corrupt linearized search space: " +
+ _searchSpace.length+" vs "+part.getMatPointsExt().length);
+ }
+
+ public InterestingPoint[] getSortedSearchSpace() {
+ return _searchSpace;
+ }
+
+ //true iff the given plan assignment enables any known cutset
+ public boolean isCutSet(boolean[] plan) {
+ for( CutSet cs : _cutSets )
+ if( isCutSet(cs, plan) )
+ return true;
+ return false;
+ }
+
+ //true iff all materialization points of the cutset are enabled in the plan
+ public boolean isCutSet(CutSet cs, boolean[] plan) {
+ boolean ret = true;
+ for(int i=0; i<cs.posCut.length && ret; i++)
+ ret &= plan[cs.posCut[i]];
+ return ret;
+ }
+
+ public CutSet getCutSet(boolean[] plan) {
+ for( CutSet cs : _cutSets )
+ if( isCutSet(cs, plan) )
+ return cs;
+ throw new RuntimeException("No valid cut set found.");
+ }
+
+ //number of plans skipped by decomposing at the matched cutset:
+ //2^(number of assignment bits after the last cutset position)
+ public long getNumSkipPlans(boolean[] plan) {
+ for( CutSet cs : _cutSets )
+ if( isCutSet(cs, plan) ) {
+ int pos = cs.posCut[cs.posCut.length-1];
+ return (long) Math.pow(2, plan.length-pos-1);
+ }
+ throw new RuntimeException("Failed to compute "
+ + "number of skip plans for plan without cutset.");
+ }
+
+
+ //decomposes the problem at the matched cutset into two subproblems
+ public SubProblem[] getSubproblems(boolean[] plan) {
+ CutSet cs = getCutSet(plan);
+ return new SubProblem[] {
+ new SubProblem(cs.cut.length, cs.posLeft, cs.left),
+ new SubProblem(cs.cut.length, cs.posRight, cs.right)};
+ }
+
+ @Override
+ public String toString() {
+ return "ReachabilityGraph("+_matPoints.size()+"):\n"
+ + _root.explain(new HashSet<>());
+ }
+
+ //builds the reachability graph by linking materialization points along
+ //hop input edges; (hop, parent link) pairs are memoized to avoid rework
+ private void addInputNodeLinks(Hop current, NodeLink parent, PlanPartition part,
+ CPlanMemoTable memo, HashSet<VisitMarkCost> visited)
+ {
+ if( visited.contains(new VisitMarkCost(current.getHopID(), parent._ID)) )
+ return;
+
+ //process children
+ for( Hop in : current.getInput() ) {
+ if( InterestingPoint.isMatPoint(part.getMatPointsExt(), current.getHopID(), in.getHopID()) ) {
+ NodeLink tmp = _matPoints.get(Pair.of(current.getHopID(), in.getHopID()));
+ parent.addInput(tmp);
+ addInputNodeLinks(in, tmp, part, memo, visited);
+ }
+ else
+ addInputNodeLinks(in, parent, part, memo, visited);
+ }
+
+ visited.add(new VisitMarkCost(current.getHopID(), parent._ID));
+ }
+
+ //collects all nodes reachable from current without crossing probe nodes
+ private void rCollectInputs(NodeLink current, HashSet<NodeLink> probe, HashSet<NodeLink> inputs) {
+ for( NodeLink c : current._inputs )
+ if( !probe.contains(c) ) {
+ rCollectInputs(c, probe, inputs);
+ inputs.add(c);
+ }
+ }
+
+ //scores cutset candidates and constructs valid cutsets; candidates that
+ //do not cleanly separate the graph are returned via 'remain'
+ private ArrayList<Pair<CutSet,Double>> evaluateCutSets(ArrayList<ArrayList<NodeLink>> candCS, ArrayList<ArrayList<NodeLink>> remain) {
+ ArrayList<Pair<CutSet,Double>> cutSets = new ArrayList<>();
+
+ for( ArrayList<NodeLink> cand : candCS ) {
+ HashSet<NodeLink> probe = new HashSet<>(cand);
+
+ //determine subproblems for cutset candidates
+ HashSet<NodeLink> part1 = new HashSet<>();
+ rCollectInputs(_root, probe, part1);
+ HashSet<NodeLink> part2 = new HashSet<>();
+ for( NodeLink rNode : cand )
+ rCollectInputs(rNode, probe, part2);
+
+ //select, score and create cutsets
+ if( !CollectionUtils.containsAny(part1, part2)
+ && !part1.isEmpty() && !part2.isEmpty()) {
+ //score cutsets (smaller is better)
+ double base = Math.pow(2, _matPoints.size());
+ double numComb = Math.pow(2, cand.size());
+ double score = (numComb-1)/numComb * base
+ + 1/numComb * Math.pow(2, part1.size())
+ + 1/numComb * Math.pow(2, part2.size());
+
+ //construct cutset
+ cutSets.add(Pair.of(new CutSet(
+ cand.stream().map(p->p._p).toArray(InterestingPoint[]::new),
+ part1.stream().map(p->p._p).toArray(InterestingPoint[]::new),
+ part2.stream().map(p->p._p).toArray(InterestingPoint[]::new)), score));
+ }
+ else {
+ remain.add(cand);
+ }
+ }
+
+ return cutSets;
+ }
+
+ //subproblem of the decomposed search space: offset of the fixed cutset
+ //assignment, plus positions and points of the free variables
+ public static class SubProblem {
+ public int offset;
+ public int[] freePos;
+ public InterestingPoint[] freeMat;
+
+ public SubProblem(int off, int[] pos, InterestingPoint[] mat) {
+ offset = off;
+ freePos = pos;
+ freeMat = mat;
+ }
+ }
+
+ //cutset of materialization points with the left/right partitions it
+ //separates, and their positions in the linearized search space
+ public static class CutSet {
+ public InterestingPoint[] cut;
+ public InterestingPoint[] left;
+ public InterestingPoint[] right;
+ public int[] posCut;
+ public int[] posLeft;
+ public int[] posRight;
+
+ public CutSet(InterestingPoint[] cutPoints,
+ InterestingPoint[] l, InterestingPoint[] r) {
+ cut = cutPoints;
+ left = l;
+ right = r;
+ }
+
+ public void updatePos(int index) {
+ posCut = new int[cut.length];
+ for(int i=0; i<posCut.length; i++)
+ posCut[i] = index + i;
+ }
+
+ //removes already-assigned points (blacklist) from both partitions
+ public void updatePartitions(HashMap<InterestingPoint,Integer> blacklist) {
+ left = Arrays.stream(left).filter(p -> !blacklist.containsKey(p))
+ .toArray(InterestingPoint[]::new);
+ right = Arrays.stream(right).filter(p -> !blacklist.containsKey(p))
+ .toArray(InterestingPoint[]::new);
+ }
+
+ public void updatePartitionIndexes(HashMap<InterestingPoint,Integer> probe) {
+ posLeft = new int[left.length];
+ for(int i=0; i<left.length; i++)
+ posLeft[i] = probe.get(left[i]);
+ posRight = new int[right.length];
+ for(int i=0; i<right.length; i++)
+ posRight[i] = probe.get(right[i]);
+ }
+
+ //prepends the cut points to both partitions
+ public void finalizePartition() {
+ left = (InterestingPoint[]) ArrayUtils.addAll(cut, left);
+ right = (InterestingPoint[]) ArrayUtils.addAll(cut, right);
+ }
+
+ @Override
+ public String toString() {
+ return "Cut : "+Arrays.toString(cut);
+ }
+ }
+
+ //graph node wrapping an interesting point (null for the artificial root)
+ //with links to its input nodes
+ private static class NodeLink implements Comparable<NodeLink>
+ {
+ private static final IDSequence _seqID = new IDSequence();
+
+ private ArrayList<NodeLink> _inputs = new ArrayList<>();
+ private long _ID; //unique node ID
+ private InterestingPoint _p; //wrapped point, null for root
+
+ public NodeLink(InterestingPoint p) {
+ _ID = _seqID.getNextID();
+ _p = p;
+ }
+
+ public void addInput(NodeLink in) {
+ _inputs.add(in);
+ }
+
+ //NOTE(review): equals is structural over the input ID sequence only
+ //(ignores _ID/_p) and hashCode is NOT overridden, so hash-based
+ //collections fall back to identity hashing — confirm this is intended
+ //(HashSet<NodeLink> usage above relies on identity semantics)
+ @Override
+ public boolean equals(Object o) {
+ if( !(o instanceof NodeLink) )
+ return false;
+ NodeLink that = (NodeLink) o;
+ boolean ret = (_inputs.size() == that._inputs.size());
+ for( int i=0; i<_inputs.size() && ret; i++ )
+ ret &= (_inputs.get(i)._ID == that._inputs.get(i)._ID);
+ return ret;
+ }
+
+ //orders by decreasing number of inputs, then lexicographically by input IDs
+ @Override
+ public int compareTo(NodeLink that) {
+ if( _inputs.size() > that._inputs.size() )
+ return -1;
+ else if( _inputs.size() < that._inputs.size() )
+ return 1;
+ for( int i=0; i<_inputs.size(); i++ ) {
+ int comp = Long.compare(_inputs.get(i)._ID,
+ that._inputs.get(i)._ID);
+ if( comp != 0 )
+ return comp;
+ }
+ return 0;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder inputs = new StringBuilder();
+ for(NodeLink in : _inputs) {
+ if( inputs.length() > 0 )
+ inputs.append(",");
+ inputs.append(in._ID);
+ }
+ return _ID+" ("+inputs.toString()+") "+((_p!=null)?_p:"null");
+ }
+
+ //recursive pretty-printer over the graph (one line per node, children first)
+ private String explain(HashSet<Long> visited) {
+ if( visited.contains(_ID) )
+ return "";
+ //add children
+ StringBuilder sb = new StringBuilder();
+ StringBuilder inputs = new StringBuilder();
+ for(NodeLink in : _inputs) {
+ String tmp = in.explain(visited);
+ if( !tmp.isEmpty() )
+ sb.append(tmp + "\n");
+ if( inputs.length() > 0 )
+ inputs.append(",");
+ inputs.append(in._ID);
+ }
+ //add node itself
+ sb.append(_ID+" ("+inputs+") "+((_p!=null)?_p:"null"));
+ visited.add(_ID);
+
+ return sb.toString();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index edbcdf9..4078060 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.codegen.template;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@@ -36,6 +37,8 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.codegen.SpoofCompiler;
+import org.apache.sysml.hops.codegen.opt.InterestingPoint;
+import org.apache.sysml.hops.codegen.opt.PlanSelection;
import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -53,6 +56,18 @@ public class CPlanMemoTable
_plansBlacklist = new HashSet<Long>();
}
+ //NOTE(review): these accessors expose internal mutable state to the new
+ //cost-based optimizer; callers are trusted not to corrupt the memo table
+ public HashMap<Long, List<MemoTableEntry>> getPlans() {
+ return _plans;
+ }
+
+ public HashSet<Long> getPlansBlacklisted() {
+ return _plansBlacklist;
+ }
+
+ public HashMap<Long, Hop> getHopRefs() {
+ return _hopRefs;
+ }
+
public void addHop(Hop hop) {
_hopRefs.put(hop.getHopID(), hop);
}
@@ -78,6 +93,14 @@ public class CPlanMemoTable
.anyMatch(p -> (!checkClose||!p.closed) && probe.contains(p.type));
}
+ //true iff the hop has any memo entry whose type is NOT in the given set,
+ //optionally restricted to entries with child plan refs and non-CELL types
+ public boolean containsNotIn(long hopID, Collection<TemplateType> types,
+ boolean checkChildRefs, boolean excludeCell) {
+ return contains(hopID) && get(hopID).stream()
+ .anyMatch(p -> (!checkChildRefs || p.hasPlanRef())
+ && (!excludeCell || p.type!=TemplateType.CELL)
+ && !types.contains(p.type));
+ }
+
public int countEntries(long hopID) {
return get(hopID).size();
}
@@ -85,7 +108,7 @@ public class CPlanMemoTable
public int countEntries(long hopID, TemplateType type) {
return (int) get(hopID).stream()
.filter(p -> p.type==type).count();
- }
+ }
public boolean containsTopLevel(long hopID) {
return !_plansBlacklist.contains(hopID)
@@ -133,7 +156,7 @@ public class CPlanMemoTable
.distinct().collect(Collectors.toList()));
}
- public void pruneRedundant(long hopID) {
+ public void pruneRedundant(long hopID, boolean pruneDominated, InterestingPoint[] matPoints) {
if( !contains(hopID) )
return;
@@ -146,7 +169,7 @@ public class CPlanMemoTable
//prune dominated plans (e.g., opened plan subsumed by fused plan
//if single consumer of input; however this only applies to fusion
//heuristic that only consider materialization points)
- if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+ if( pruneDominated ) {
HashSet<MemoTableEntry> rmList = new HashSet<MemoTableEntry>();
List<MemoTableEntry> list = _plans.get(hopID);
Hop hop = _hopRefs.get(hopID);
@@ -155,9 +178,12 @@ public class CPlanMemoTable
if( e1 != e2 && e1.subsumes(e2) ) {
//check that childs don't have multiple consumers
boolean rmSafe = true;
- for( int i=0; i<=2; i++ )
+ for( int i=0; i<=2; i++ ) {
rmSafe &= (e1.isPlanRef(i) && !e2.isPlanRef(i)) ?
- hop.getInput().get(i).getParent().size()==1 : true;
+ (matPoints!=null && !InterestingPoint.isMatPoint(
+ matPoints, hopID, e1.input(i)))
+ || hop.getInput().get(i).getParent().size()==1 : true;
+ }
if( rmSafe )
rmList.add(e2);
}
@@ -194,12 +220,14 @@ public class CPlanMemoTable
//prune dominated plans (e.g., plan referenced by other plan and this
//other plan is single consumer) by marking it as blacklisted because
//the chain of entries is still required for cplan construction
- for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
- for( MemoTableEntry me : e.getValue() ) {
- for( int i=0; i<=2; i++ )
- if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
- _plansBlacklist.add(me.input(i));
- }
+ if( SpoofCompiler.PLAN_SEL_POLICY.isHeuristic() ) {
+ for( Entry<Long, List<MemoTableEntry>> e : _plans.entrySet() )
+ for( MemoTableEntry me : e.getValue() ) {
+ for( int i=0; i<=2; i++ )
+ if( me.isPlanRef(i) && _hopRefs.get(me.input(i)).getParent().size()==1 )
+ _plansBlacklist.add(me.input(i));
+ }
+ }
//core plan selection
PlanSelection selector = SpoofCompiler.createPlanSelector();
@@ -232,6 +260,16 @@ public class CPlanMemoTable
.distinct().collect(Collectors.toList());
}
+ //returns the distinct template types of all memo entries of the given hop
+ //that have a plan reference at position refAt (empty list if hop unknown)
+ public List<TemplateType> getDistinctTemplateTypes(long hopID, int refAt) {
+ if(!contains(hopID))
+ return Collections.emptyList();
+ //return distinct template types with reference at given position
+ return _plans.get(hopID).stream()
+ .filter(p -> p.isPlanRef(refAt))
+ .map(p -> p.type) //extract type
+ .distinct().collect(Collectors.toList());
+ }
+
public MemoTableEntry getBest(long hopID) {
List<MemoTableEntry> tmp = get(hopID);
if( tmp == null || tmp.isEmpty() )
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
deleted file mode 100644
index f8a12fd..0000000
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelection.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-import org.apache.sysml.hops.rewrite.HopRewriteUtils;
-import org.apache.sysml.runtime.util.UtilFunctions;
-
-public abstract class PlanSelection
-{
- private final HashMap<Long, List<MemoTableEntry>> _bestPlans =
- new HashMap<Long, List<MemoTableEntry>>();
- private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
-
- /**
- * Given a HOP DAG G, and a set of partial fusions plans P, find the set of optimal,
- * non-conflicting fusion plans P' that applied to G minimizes costs C with
- * P' = \argmin_{p \subseteq P} C(G, p) s.t. Z \vDash p, where Z is a set of
- * constraints such as memory budgets and block size restrictions per fused operator.
- *
- * @param memo partial fusion plans P
- * @param roots entry points of HOP DAG G
- */
- public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots);
-
- /**
- * Determines if the given partial fusion plan is valid.
- *
- * @param me memo table entry
- * @param hop current hop
- * @return true if entry is valid as top-level plan
- */
- protected static boolean isValid(MemoTableEntry me, Hop hop) {
- return (me.type == TemplateType.OuterProdTpl
- && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)))
- || (me.type == TemplateType.RowTpl)
- || (me.type == TemplateType.CellTpl)
- || (me.type == TemplateType.MultiAggTpl);
- }
-
- protected void addBestPlan(long hopID, MemoTableEntry me) {
- if( me == null ) return;
- if( !_bestPlans.containsKey(hopID) )
- _bestPlans.put(hopID, new ArrayList<MemoTableEntry>());
- _bestPlans.get(hopID).add(me);
- }
-
- protected HashMap<Long, List<MemoTableEntry>> getBestPlans() {
- return _bestPlans;
- }
-
- protected boolean isVisited(long hopID, TemplateType type) {
- return _visited.contains(new VisitMark(hopID, type));
- }
-
- protected void setVisited(long hopID, TemplateType type) {
- _visited.add(new VisitMark(hopID, type));
- }
-
- /**
- * Basic plan comparator to compare memo table entries with regard to
- * a pre-defined template preference order and the number of references.
- */
- protected static class BasicPlanComparator implements Comparator<MemoTableEntry> {
- @Override
- public int compare(MemoTableEntry o1, MemoTableEntry o2) {
- //for different types, select preferred type
- if( o1.type != o2.type )
- return Integer.compare(o1.type.getRank(), o2.type.getRank());
-
- //for same type, prefer plan with more refs
- return Integer.compare(
- 3-o1.countPlanRefs(), 3-o2.countPlanRefs());
- }
- }
-
- private static class VisitMark {
- private final long _hopID;
- private final TemplateType _type;
-
- public VisitMark(long hopID, TemplateType type) {
- _hopID = hopID;
- _type = type;
- }
- @Override
- public int hashCode() {
- return UtilFunctions.longHashCode(
- _hopID, (_type!=null)?_type.hashCode():0);
- }
- @Override
- public boolean equals(Object o) {
- return (o instanceof VisitMark
- && _hopID == ((VisitMark)o)._hopID
- && _type == ((VisitMark)o)._type);
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7b4a3418/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
deleted file mode 100644
index a455302..0000000
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseAll.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.codegen.template;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.Map.Entry;
-import java.util.HashSet;
-import java.util.List;
-
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
-
-/**
- * This plan selection heuristic aims for maximal fusion, which
- * potentially leads to overlapping fused operators and thus,
- * redundant computation but with a minimal number of materialized
- * intermediate results.
- *
- */
-public class PlanSelectionFuseAll extends PlanSelection
-{
- @Override
- public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
- //pruning and collection pass
- for( Hop hop : roots )
- rSelectPlans(memo, hop, null);
-
- //take all distinct best plans
- for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
- memo.setDistinct(e.getKey(), e.getValue());
- }
-
- private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType)
- {
- if( isVisited(current.getHopID(), currentType) )
- return;
-
- //step 1: prune subsumed plans of same type
- if( memo.contains(current.getHopID()) ) {
- HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
- List<MemoTableEntry> hopP = memo.get(current.getHopID());
- for( MemoTableEntry e1 : hopP )
- for( MemoTableEntry e2 : hopP )
- if( e1 != e2 && e1.subsumes(e2) )
- rmSet.add(e2);
- memo.remove(current, rmSet);
- }
-
- //step 2: select plan for current path
- MemoTableEntry best = null;
- if( memo.contains(current.getHopID()) ) {
- if( currentType == null ) {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> isValid(p, current))
- .min(new BasicPlanComparator()).orElse(null);
- }
- else {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> p.type==currentType || p.type==TemplateType.CellTpl)
- .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
- .orElse(null);
- }
- addBestPlan(current.getHopID(), best);
- }
-
- //step 3: recursively process children
- for( int i=0; i< current.getInput().size(); i++ ) {
- TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
- rSelectPlans(memo, current.getInput().get(i), pref);
- }
-
- setVisited(current.getHopID(), currentType);
- }
-}