You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/10/19 23:05:48 UTC
systemml git commit: [SYSTEMML-1903] Fix robustness codegen row ops w/ unknowns
Repository: systemml
Updated Branches:
refs/heads/master 4f29b3485 -> 323dd72a8
[SYSTEMML-1903] Fix robustness codegen row ops w/ unknowns
This patch fixes special cases of codegen row templates with partially
unknown dimensions. Handling these cases is important for robustness
during initial compilation, even though such unknowns trigger dynamic
recompilation at runtime.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/323dd72a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/323dd72a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/323dd72a
Branch: refs/heads/master
Commit: 323dd72a8ed18687aa3019387c4ab7b0598bd9d5
Parents: 4f29b34
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu Oct 19 15:07:54 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Oct 19 16:06:14 2017 -0700
----------------------------------------------------------------------
.../hops/codegen/template/TemplateRow.java | 38 ++++++++++----------
.../hops/codegen/template/TemplateUtils.java | 2 +-
2 files changed, 21 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index 0389983..e664b9f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -250,7 +250,7 @@ public class TemplateRow extends TemplateBase
else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
//vector add without temporary copy
if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
- out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
+ out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
((CNodeBinary)cdata1).getType().getVectorAddPrimitive());
else
out = cdata1;
@@ -269,7 +269,7 @@ public class TemplateRow extends TemplateBase
{
//correct input under transpose
cdata1 = TemplateUtils.skipTranspose(cdata1, hop.getInput().get(0), tmp, compileLiterals);
- inHops.remove(hop.getInput().get(0));
+ inHops.remove(hop.getInput().get(0));
inHops.add(hop.getInput().get(0).getInput().get(0));
//note: vectorMultAdd applicable to vector-scalar, and vector-vector
@@ -310,7 +310,8 @@ public class TemplateRow extends TemplateBase
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
// if one input is a matrix then we need to do vector by scalar operations
- if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 )
+ if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1
+ || (!hop.dimsKnown() && cdata1.getDataType()==DataType.MATRIX ) )
{
if( HopRewriteUtils.isUnary(hop, SUPPORTED_VECT_UNARY) ) {
String opname = "VECT_"+((UnaryOp)hop).getOp().name();
@@ -320,12 +321,11 @@ public class TemplateRow extends TemplateBase
}
else
throw new RuntimeException("Unsupported unary matrix "
- + "operation: " + ((UnaryOp)hop).getOp().name());
+ + "operation: " + ((UnaryOp)hop).getOp().name());
}
else //general scalar case
{
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
-
String primitiveOpName = ((UnaryOp)hop).getOp().toString();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
@@ -355,7 +355,9 @@ public class TemplateRow extends TemplateBase
// if one input is a matrix then we need to do vector by scalar operations
if( (hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1)
- || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1))
+ || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1)
+ || (!(hop.dimsKnown() && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown())
+ && (cdata1.getDataType().isMatrix() || cdata2.getDataType().isMatrix())))
{
if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) {
if( TemplateUtils.isMatrix(cdata1) && (TemplateUtils.isMatrix(cdata2)
@@ -371,14 +373,14 @@ public class TemplateRow extends TemplateBase
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
}
- if( cdata1 instanceof CNodeData && inHops2.isEmpty()
+ if( cdata1 instanceof CNodeData && inHops2.isEmpty()
&& !(cdata1.getDataType()==DataType.SCALAR) ) {
inHops2.put("X", hop.getInput().get(0));
}
}
else
throw new RuntimeException("Unsupported binary matrix "
- + "operation: " + ((BinaryOp)hop).getOp().name());
+ + "operation: " + ((BinaryOp)hop).getOp().name());
}
else //one input is a vector/scalar other is a scalar
{
@@ -389,7 +391,7 @@ public class TemplateRow extends TemplateBase
|| (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData
&& hop.getInput().get(1).getDataType().isMatrix()))
cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
- out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
+ out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));
}
}
else if(hop instanceof TernaryOp)
@@ -405,16 +407,16 @@ public class TemplateRow extends TemplateBase
//construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3,
- TernaryType.valueOf(top.getOp().toString()));
+ TernaryType.valueOf(top.getOp().toString()));
}
- else if( hop instanceof ParameterizedBuiltinOp )
+ else if( hop instanceof ParameterizedBuiltinOp )
{
CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
- TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
+ TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
TernaryType.REPLACE_NAN : TernaryType.REPLACE;
out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
}
@@ -422,7 +424,7 @@ public class TemplateRow extends TemplateBase
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
out = new CNodeTernary(cdata1,
- TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
+ TemplateUtils.createCNodeData(new LiteralOp(hop.getInput().get(0).getDim2()), true),
TemplateUtils.createCNodeData(hop.getInput().get(4), true),
(!hop.dimsKnown()||hop.getDim2()>1) ? TernaryType.LOOKUP_RVECT1 : TernaryType.LOOKUP_RC1);
}
@@ -456,13 +458,13 @@ public class TemplateRow extends TemplateBase
@Override
public int compare(Hop h1, Hop h2) {
- long ncells1 = h1.isScalar() ? Long.MIN_VALUE :
- (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? Long.MAX_VALUE-1 :
+ long ncells1 = h1.isScalar() ? Long.MIN_VALUE :
+ (h1==_X) ? Long.MAX_VALUE : (h1==_B1) ? Long.MAX_VALUE-1 :
h1.dimsKnown() ? h1.getLength() : Long.MAX_VALUE-2;
- long ncells2 = h2.isScalar() ? Long.MIN_VALUE :
- (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? Long.MAX_VALUE-1 :
+ long ncells2 = h2.isScalar() ? Long.MIN_VALUE :
+ (h2==_X) ? Long.MAX_VALUE : (h2==_B1) ? Long.MAX_VALUE-1 :
h2.dimsKnown() ? h2.getLength() : Long.MAX_VALUE-2;
- return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 1 : 0;
+ return (ncells1 > ncells2) ? -1 : (ncells1 < ncells2) ? 1 : 0;
}
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/323dd72a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 497dae0..96e15cb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -184,7 +184,7 @@ public class TemplateUtils
public static RowType getRowType(Hop output, Hop... inputs) {
Hop X = inputs[0];
Hop B1 = (inputs.length>1) ? inputs[1] : null;
- if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || X==null )
+ if( (X!=null && HopRewriteUtils.isEqualSize(output, X)) || X==null || !X.dimsKnown() )
return RowType.NO_AGG;
else if( ((B1!=null && output.getDim1()==X.getDim1() && output.getDim2()==B1.getDim2())
|| (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output)))