You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/10/02 07:39:19 UTC
[2/2] systemml git commit: [SYSTEMML-1943] Fix codegen fuse_all
optimizer and consolidation
[SYSTEMML-1943] Fix codegen fuse_all optimizer and consolidation
This patch fixes special cases of row operations that caused the
fuse_all optimizer to fail on Kmeans. It also includes a cleanup that
consolidates the fuse-all plan selection shared by the fuse_all
optimizer and both cost-based optimizers, and it fixes the default
config registration that stored the codegen optimizer default under the
codegen compiler key.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c27c488b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c27c488b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c27c488b
Branch: refs/heads/master
Commit: c27c488bef54887d549792c4cf6532d95c3f5c58
Parents: 8ed2516
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Oct 1 20:04:39 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon Oct 2 00:39:21 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/conf/DMLConfig.java | 2 +-
.../apache/sysml/hops/codegen/SpoofFusedOp.java | 11 +++++
.../sysml/hops/codegen/cplan/CNodeRow.java | 3 +-
.../sysml/hops/codegen/opt/PlanSelection.java | 46 +++++++++++++++++++
.../hops/codegen/opt/PlanSelectionFuseAll.java | 47 +-------------------
.../codegen/opt/PlanSelectionFuseCostBased.java | 45 +------------------
.../opt/PlanSelectionFuseCostBasedV2.java | 47 +-------------------
.../hops/codegen/template/TemplateUtils.java | 2 +
.../sysml/runtime/codegen/SpoofRowwise.java | 25 ++++++-----
9 files changed, 79 insertions(+), 149 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 6a331a6..9835b4d 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -127,7 +127,7 @@ public class DMLConfig
_defaultVals.put(COMPRESSED_LINALG, Compression.CompressConfig.AUTO.name() );
_defaultVals.put(CODEGEN, "false" );
_defaultVals.put(CODEGEN_COMPILER, CompilerType.AUTO.name() );
- _defaultVals.put(CODEGEN_COMPILER, PlanSelector.FUSE_COST_BASED_V2.name() );
+ _defaultVals.put(CODEGEN_OPTIMIZER, PlanSelector.FUSE_COST_BASED_V2.name() );
_defaultVals.put(CODEGEN_PLANCACHE, "true" );
_defaultVals.put(CODEGEN_LITERALS, "1" );
_defaultVals.put(NATIVE_BLAS, "none" );
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 81b226d..56bfb61 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -42,6 +42,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
ROW_DIMS,
COLUMN_DIMS_ROWS,
COLUMN_DIMS_COLS,
+ RANK_DIMS_COLS,
SCALAR,
MULTI_SCALAR,
ROW_RANK_DIMS, // right wdivmm, row mm
@@ -163,6 +164,12 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
case COLUMN_DIMS_COLS:
ret = new long[]{1, mc.getCols(), -1};
break;
+ case RANK_DIMS_COLS: {
+ MatrixCharacteristics mc2 = memo.getAllInputStats(getInput().get(1));
+ if( mc2.dimsKnown() )
+ ret = new long[]{1, mc2.getCols(), -1};
+ break;
+ }
case INPUT_DIMS:
ret = new long[]{mc.getRows(), mc.getCols(), -1};
break;
@@ -219,6 +226,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
setDim1(1);
setDim2(getInput().get(0).getDim2());
break;
+ case RANK_DIMS_COLS:
+ setDim1(1);
+ setDim2(getInput().get(1).getDim2());
+ break;
case INPUT_DIMS:
setDim1(getInput().get(0).getDim1());
setDim2(getInput().get(0).getDim2());
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index 07822d9..9235216 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -158,7 +158,8 @@ public class CNodeRow extends CNodeTpl
case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS;
- case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
+ case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
+ case COL_AGG_B1R: return SpoofOutputDimsType.RANK_DIMS_COLS;
default:
throw new RuntimeException("Unsupported row type: "+_type.toString());
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
index 21f4fd3..4cf56c4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -34,6 +34,9 @@ import org.apache.sysml.runtime.util.UtilFunctions;
public abstract class PlanSelection
{
+ private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator();
+ private final TypedPlanComparator _typedCompare = new TypedPlanComparator();
+
private final HashMap<Long, List<MemoTableEntry>> _bestPlans =
new HashMap<Long, List<MemoTableEntry>>();
private final HashSet<VisitMark> _visited = new HashSet<VisitMark>();
@@ -84,6 +87,49 @@ public abstract class PlanSelection
_visited.add(new VisitMark(hopID, type));
}
+ protected void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition)
+ {
+ if( isVisited(current.getHopID(), currentType)
+ || (partition!=null && !partition.contains(current.getHopID())) )
+ return;
+
+ //step 1: prune subsumed plans of same type
+ if( memo.contains(current.getHopID()) ) {
+ HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
+ List<MemoTableEntry> hopP = memo.get(current.getHopID());
+ for( MemoTableEntry e1 : hopP )
+ for( MemoTableEntry e2 : hopP )
+ if( e1 != e2 && e1.subsumes(e2) )
+ rmSet.add(e2);
+ memo.remove(current, rmSet);
+ }
+
+ //step 2: select plan for current path
+ MemoTableEntry best = null;
+ if( memo.contains(current.getHopID()) ) {
+ if( currentType == null ) {
+ best = memo.get(current.getHopID()).stream()
+ .filter(p -> isValid(p, current))
+ .min(BASE_COMPARE).orElse(null);
+ }
+ else {
+ _typedCompare.setType(currentType);
+ best = memo.get(current.getHopID()).stream()
+ .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
+ .min(_typedCompare).orElse(null);
+ }
+ addBestPlan(current.getHopID(), best);
+ }
+
+ //step 3: recursively process children
+ for( int i=0; i< current.getInput().size(); i++ ) {
+ TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
+ rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
+ }
+
+ setVisited(current.getHopID(), currentType);
+ }
+
/**
* Basic plan comparator to compare memo table entries with regard to
* a pre-defined template preference order and the number of references.
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
index 8636bea..3e0561d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseAll.java
@@ -20,15 +20,12 @@
package org.apache.sysml.hops.codegen.opt;
import java.util.ArrayList;
-import java.util.Comparator;
import java.util.Map.Entry;
-import java.util.HashSet;
import java.util.List;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable;
import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntry;
-import org.apache.sysml.hops.codegen.template.TemplateBase.TemplateType;
/**
* This plan selection heuristic aims for maximal fusion, which
@@ -43,52 +40,10 @@ public class PlanSelectionFuseAll extends PlanSelection
public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots) {
//pruning and collection pass
for( Hop hop : roots )
- rSelectPlans(memo, hop, null);
+ rSelectPlansFuseAll(memo, hop, null, null);
//take all distinct best plans
for( Entry<Long, List<MemoTableEntry>> e : getBestPlans().entrySet() )
memo.setDistinct(e.getKey(), e.getValue());
}
-
- private void rSelectPlans(CPlanMemoTable memo, Hop current, TemplateType currentType)
- {
- if( isVisited(current.getHopID(), currentType) )
- return;
-
- //step 1: prune subsumed plans of same type
- if( memo.contains(current.getHopID()) ) {
- HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
- List<MemoTableEntry> hopP = memo.get(current.getHopID());
- for( MemoTableEntry e1 : hopP )
- for( MemoTableEntry e2 : hopP )
- if( e1 != e2 && e1.subsumes(e2) )
- rmSet.add(e2);
- memo.remove(current, rmSet);
- }
-
- //step 2: select plan for current path
- MemoTableEntry best = null;
- if( memo.contains(current.getHopID()) ) {
- if( currentType == null ) {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> isValid(p, current))
- .min(new BasicPlanComparator()).orElse(null);
- }
- else {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
- .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
- .orElse(null);
- }
- addBestPlan(current.getHopID(), best);
- }
-
- //step 3: recursively process children
- for( int i=0; i< current.getInput().size(); i++ ) {
- TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
- rSelectPlans(memo, current.getInput().get(i), pref);
- }
-
- setVisited(current.getHopID(), currentType);
- }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
index acb90e2..f67604d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBased.java
@@ -507,52 +507,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection
}
}
- visited.add(current.getHopID());
+ visited.add(current.getHopID());
}
- private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition)
- {
- if( isVisited(current.getHopID(), currentType)
- || !partition.contains(current.getHopID()) )
- return;
-
- //step 1: prune subsumed plans of same type
- if( memo.contains(current.getHopID()) ) {
- HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
- List<MemoTableEntry> hopP = memo.get(current.getHopID());
- for( MemoTableEntry e1 : hopP )
- for( MemoTableEntry e2 : hopP )
- if( e1 != e2 && e1.subsumes(e2) )
- rmSet.add(e2);
- memo.remove(current, rmSet);
- }
-
- //step 2: select plan for current path
- MemoTableEntry best = null;
- if( memo.contains(current.getHopID()) ) {
- if( currentType == null ) {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> isValid(p, current))
- .min(new BasicPlanComparator()).orElse(null);
- }
- else {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
- .min(Comparator.comparing(p -> 7-((p.type==currentType)?4:0)-p.countPlanRefs()))
- .orElse(null);
- }
- addBestPlan(current.getHopID(), best);
- }
-
- //step 3: recursively process children
- for( int i=0; i< current.getInput().size(); i++ ) {
- TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
- rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
- }
-
- setVisited(current.getHopID(), currentType);
- }
-
private static boolean[] createAssignment(int len, int pos) {
boolean[] ret = new boolean[len];
int tmp = pos;
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 8d1c4c0..31e8427 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -98,8 +98,6 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
private static final IDSequence COST_ID = new IDSequence();
private static final TemplateRow ROW_TPL = new TemplateRow();
- private static final BasicPlanComparator BASE_COMPARE = new BasicPlanComparator();
- private final TypedPlanComparator _typedCompare = new TypedPlanComparator();
@Override
public void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots)
@@ -726,50 +724,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
}
}
- visited.add(current.getHopID());
- }
-
- private void rSelectPlansFuseAll(CPlanMemoTable memo, Hop current, TemplateType currentType, HashSet<Long> partition)
- {
- if( isVisited(current.getHopID(), currentType)
- || !partition.contains(current.getHopID()) )
- return;
-
- //step 1: prune subsumed plans of same type
- if( memo.contains(current.getHopID()) ) {
- HashSet<MemoTableEntry> rmSet = new HashSet<MemoTableEntry>();
- List<MemoTableEntry> hopP = memo.get(current.getHopID());
- for( MemoTableEntry e1 : hopP )
- for( MemoTableEntry e2 : hopP )
- if( e1 != e2 && e1.subsumes(e2) )
- rmSet.add(e2);
- memo.remove(current, rmSet);
- }
-
- //step 2: select plan for current path
- MemoTableEntry best = null;
- if( memo.contains(current.getHopID()) ) {
- if( currentType == null ) {
- best = memo.get(current.getHopID()).stream()
- .filter(p -> isValid(p, current))
- .min(BASE_COMPARE).orElse(null);
- }
- else {
- _typedCompare.setType(currentType);
- best = memo.get(current.getHopID()).stream()
- .filter(p -> p.type==currentType || p.type==TemplateType.CELL)
- .min(_typedCompare).orElse(null);
- }
- addBestPlan(current.getHopID(), best);
- }
-
- //step 3: recursively process children
- for( int i=0; i< current.getInput().size(); i++ ) {
- TemplateType pref = (best!=null && best.isPlanRef(i))? best.type : null;
- rSelectPlansFuseAll(memo, current.getInput().get(i), pref, partition);
- }
-
- setVisited(current.getHopID(), currentType);
+ visited.add(current.getHopID());
}
/////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 06d83bd..4dc0bf2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -204,6 +204,8 @@ public class TemplateUtils
return RowType.COL_AGG_B1_T;
else if( B1 != null && output.getDim1()==B1.getDim2() && output.getDim2()==X.getDim2())
return RowType.COL_AGG_B1;
+ else if( B1 != null && output.getDim1()==1 && B1.getDim2() == output.getDim2() )
+ return RowType.COL_AGG_B1R;
else if( X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2() )
return RowType.NO_AGG_CONST;
else
http://git-wip-us.apache.org/repos/asf/systemml/blob/c27c488b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 8b12e7e..311c27f 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -47,22 +47,25 @@ public abstract class SpoofRowwise extends SpoofOperator
private static final long serialVersionUID = 6242910797139642998L;
public enum RowType {
- NO_AGG, //no aggregation
- NO_AGG_B1, //no aggregation w/ matrix mult B1
+ NO_AGG, //no aggregation
+ NO_AGG_B1, //no aggregation w/ matrix mult B1
NO_AGG_CONST, //no aggregation w/ expansion/contraction
- FULL_AGG, //full row/col aggregation
- ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v)
- COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X)
- COL_AGG_T, //transposed col aggregation (e.g., t(X) %*% y)
+ FULL_AGG, //full row/col aggregation
+ ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v)
+ COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X)
+ COL_AGG_T, //transposed col aggregation (e.g., t(X) %*% y)
COL_AGG_B1, //col aggregation w/ matrix mult B1
- COL_AGG_B1_T; //transposed col aggregation w/ matrix mult B1
+ COL_AGG_B1_T, //transposed col aggregation w/ matrix mult B1
+ COL_AGG_B1R; //col aggregation w/ matrix mult B1 to row vector
public boolean isColumnAgg() {
- return (this == COL_AGG || this == COL_AGG_T)
- || (this == COL_AGG_B1) || (this == COL_AGG_B1_T);
+ return this == COL_AGG || this == COL_AGG_T
+ || this == COL_AGG_B1 || this == COL_AGG_B1_T
+ || this == COL_AGG_B1R;
}
public boolean isRowTypeB1() {
- return (this == NO_AGG_B1) || (this == COL_AGG_B1) || (this == COL_AGG_B1_T);
+ return this == NO_AGG_B1 || this == COL_AGG_B1
+ || this == COL_AGG_B1_T || this == COL_AGG_B1R;
}
public boolean isRowTypeB1ColumnAgg() {
return (this == COL_AGG_B1) || (this == COL_AGG_B1_T);
@@ -268,7 +271,7 @@ public abstract class SpoofRowwise extends SpoofOperator
case COL_AGG_T: out.reset(n, 1, false); break;
case COL_AGG_B1: out.reset(n2, n, false); break;
case COL_AGG_B1_T: out.reset(n, n2, false); break;
-
+ case COL_AGG_B1R: out.reset(1, n2, false); break;
}
out.allocateDenseBlock();
}