You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/09/27 06:39:29 UTC
[2/2] systemml git commit: [SYSTEMML-1933] Generalized codegen cbind handling (part 2), cleanups
[SYSTEMML-1933] Generalized codegen cbind handling (part 2), cleanups
This patch finalizes the codegen cbind generalization. We now fuse
cbinds not just with constant vectors but with arbitrary vector inputs.
This significantly extends its applicability and also revealed a number
of smaller robustness issues that needed fixing (e.g., row type
selection, row indexing on the main input, and the switch from row to
cell templates).
On GLM-probit (100M x 10, 20/10 iterations), this patch improved
end-to-end performance (with codegen) from 337s to 185s.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/328e8a00
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/328e8a00
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/328e8a00
Branch: refs/heads/master
Commit: 328e8a0020c17c072f13d9a1bc9334af968b9c2b
Parents: 682fc44
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Sep 26 22:24:54 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Sep 26 23:38:54 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/codegen/SpoofCompiler.java | 3 +--
.../sysml/hops/codegen/cplan/CNodeBinary.java | 3 ++-
.../sysml/hops/codegen/cplan/CNodeUnary.java | 4 ++-
.../sysml/hops/codegen/opt/PlanSelection.java | 12 ++++++---
.../opt/PlanSelectionFuseCostBasedV2.java | 14 ++++++-----
.../hops/codegen/template/TemplateRow.java | 26 ++++++++++++++------
.../hops/codegen/template/TemplateUtils.java | 3 ++-
7 files changed, 43 insertions(+), 22 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 1db2910..a4a68bb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -718,8 +718,7 @@ public class SpoofCompiler
//remove invalid row templates (e.g., unsatisfied blocksize constraint)
if( tpl instanceof CNodeRow ) {
//check for invalid row cplan over column vector
- if(in1.getNumCols() == 1 || (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
- && tpl.getOutput().getDataType().isScalar()) ) {
+ if( ((CNodeRow)tpl).getRowType()==RowType.NO_AGG && tpl.getOutput().getDataType().isScalar() ) {
cplans2.remove(e.getKey());
if( LOG.isTraceEnabled() )
LOG.trace("Removed invalid row cplan w/o agg on column vector.");
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index c2b5644..42a36ac 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -270,7 +270,8 @@ public class CNodeBinary extends CNode
//generate binary operation (use sparse template, if data input)
boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData
- && _inputs.get(0).getVarname().startsWith("a")
+ && (_inputs.get(0).getVarname().startsWith("a")
+ || _inputs.get(1).getVarname().startsWith("a"))
&& !_inputs.get(0).isLiteral());
boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index b3720dd..860d35a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -87,7 +87,9 @@ public class CNodeUnary extends CNode
case EXP:
return " double %TMP% = FastMath.exp(%IN1%);\n";
case LOOKUP_R:
- return " double %TMP% = getValue(%IN1%, rowIndex);\n";
+ return sparse ?
+ " double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, 0);\n" :
+ " double %TMP% = getValue(%IN1%, rowIndex);\n";
case LOOKUP_C:
return " double %TMP% = getValue(%IN1%, n, 0, colIndex);\n";
case LOOKUP_RC:
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
index d18d156..21f4fd3 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelection.java
@@ -47,18 +47,22 @@ public abstract class PlanSelection
* @param memo partial fusion plans P
* @param roots entry points of HOP DAG G
*/
- public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots);
+ public abstract void selectPlans(CPlanMemoTable memo, ArrayList<Hop> roots);
/**
- * Determines if the given partial fusion plan is valid.
+ * Determines if the given partial fusion plan is a valid entry point
+ * of a fused operator.
*
* @param me memo table entry
* @param hop current hop
* @return true if entry is valid as top-level plan
*/
public static boolean isValid(MemoTableEntry me, Hop hop) {
- return (me.type != TemplateType.OUTER //ROW, CELL, MAGG
- || (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
+ return (me.type == TemplateType.CELL)
+ || (me.type == TemplateType.MAGG)
+ || (me.type == TemplateType.ROW && !HopRewriteUtils.isTransposeOperation(hop))
+ || (me.type == TemplateType.OUTER
+ && (me.closed || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)));
}
protected void addBestPlan(long hopID, MemoTableEntry me) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
index 7c27dcf..8d1c4c0 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/opt/PlanSelectionFuseCostBasedV2.java
@@ -43,6 +43,7 @@ import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
@@ -568,18 +569,18 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
}
}
- private static boolean isRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+ private static boolean isRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
//consider all aggregations other than root operation
MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
boolean ret = true;
for(int i=0; i<3; i++)
if( me.isPlanRef(i) )
- ret &= rIsRowTemplateWithoutAgg(memo,
+ ret &= rIsRowTemplateWithoutAggOrVects(memo,
current.getInput().get(i), visited);
return ret;
}
- private static boolean rIsRowTemplateWithoutAgg(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
+ private static boolean rIsRowTemplateWithoutAggOrVects(CPlanMemoTable memo, Hop current, HashSet<Long> visited) {
if( visited.contains(current.getHopID()) )
return true;
@@ -587,8 +588,9 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
MemoTableEntry me = memo.getBest(current.getHopID(), TemplateType.ROW);
for(int i=0; i<3; i++)
if( me!=null && me.isPlanRef(i) )
- ret &= rIsRowTemplateWithoutAgg(memo, current.getInput().get(i), visited);
- ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp);
+ ret &= rIsRowTemplateWithoutAggOrVects(memo, current.getInput().get(i), visited);
+ ret &= !(current instanceof AggUnaryOp || current instanceof AggBinaryOp
+ || HopRewriteUtils.isBinary(current, OpOp2.CBIND));
visited.add(current.getHopID());
return ret;
@@ -628,7 +630,7 @@ public class PlanSelectionFuseCostBasedV2 extends PlanSelection
for( Long hopID : part.getPartition() ) {
MemoTableEntry me = memo.getBest(hopID, TemplateType.ROW);
if( me != null && me.type == TemplateType.ROW && memo.contains(hopID, TemplateType.CELL)
- && isRowTemplateWithoutAgg(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
+ && isRowTemplateWithoutAggOrVects(memo, memo.getHopRefs().get(hopID), new HashSet<Long>())) {
List<MemoTableEntry> blacklist = memo.get(hopID, TemplateType.ROW);
memo.remove(memo.getHopRefs().get(hopID), new HashSet<MemoTableEntry>(blacklist));
if( LOG.isTraceEnabled() ) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index d9209be..1aaa84f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -78,7 +78,7 @@ public class TemplateRow extends TemplateBase
return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop)
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix()
- && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
+ && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1)))
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide(
@@ -101,9 +101,9 @@ public class TemplateRow extends TemplateBase
return !isClosed() &&
( (hop instanceof BinaryOp && isValidBinaryOperation(hop) )
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0
- && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
+ && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1)))
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
- && TemplateCell.isValidOperation(hop))
+ && TemplateCell.isValidOperation(hop))
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol
&& HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG))
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection() == Direction.RowCol
@@ -121,7 +121,9 @@ public class TemplateRow extends TemplateBase
//merge rowagg tpl with cell tpl if input is a vector
return !isClosed() &&
((hop instanceof BinaryOp && isValidBinaryOperation(hop)
- && hop.getDim1() > 1 && input.getDim1()>1)
+ && hop.getDim1() > 1 && input.getDim1()>1)
+ || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix()
+ && hop.dimsKnown() && TemplateUtils.isColVector(hop.getInput().get(1)))
||(hop instanceof AggBinaryOp
&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))
&& (input.getDim2()==1 || (input==hop.getInput().get(1)
@@ -184,6 +186,7 @@ public class TemplateRow extends TemplateBase
Hop[] sinHops = inHops.stream()
.filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
.sorted(new HopInputComparator(inHops2.get("X"),inHops2.get("B1"))).toArray(Hop[]::new);
+ inHops2.putIfAbsent("X", sinHops[0]); //robustness special cases
//construct template node
ArrayList<CNode> inputs = new ArrayList<CNode>();
@@ -326,10 +329,19 @@ public class TemplateRow extends TemplateBase
{
//special case for cbind with zeros
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
- CNode cdata2 = TemplateUtils.createCNodeData(
- HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
+ CNode cdata2 = null;
+ if( HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)) ) {
+ cdata2 = TemplateUtils.createCNodeData(HopRewriteUtils
+ .getDataGenOpConstantValue(hop.getInput().get(1)), true);
+ inHops.remove(hop.getInput().get(1)); //rm 0-matrix
+ }
+ else {
+ cdata2 = tmp.get(hop.getInput().get(1).getHopID());
+ cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
+ }
out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
- inHops.remove(hop.getInput().get(1)); //rm 0-matrix
+ if( cdata1 instanceof CNodeData )
+ inHops2.put("X", hop.getInput().get(0));
}
else if(hop instanceof BinaryOp)
{
http://git-wip-us.apache.org/repos/asf/systemml/blob/328e8a00/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 9d7baf9..21f44b2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -235,7 +235,8 @@ public class TemplateUtils
}
public static boolean isLookup(CNode node, boolean includeRC1) {
- return isUnary(node, UnaryType.LOOKUP_R, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC)
+ return isUnary(node, UnaryType.LOOKUP_C, UnaryType.LOOKUP_RC)
+ || (includeRC1 && isUnary(node, UnaryType.LOOKUP_R))
|| (includeRC1 && isTernary(node, TernaryType.LOOKUP_RC1));
}