You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/17 03:35:25 UTC
[3/3] incubator-systemml git commit: [SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info)
[SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2b7fdb2b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2b7fdb2b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2b7fdb2b
Branch: refs/heads/master
Commit: 2b7fdb2b36df0aedc6b92bf138e7f0074eed7762
Parents: cbc4509
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 16 20:32:28 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 16 20:32:28 2016 -0700
----------------------------------------------------------------------
.../rewrite/RewriteAlgebraicSimplificationStatic.java | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2b7fdb2b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 43d5791..816b55a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -165,7 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
- hi = fuseBinaryOperationChain(hop, hi, i); //e.g., X + lamda*Y -> X +* lambda Y
+ hi = fuseBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -1922,11 +1922,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
//pattern: X + lamda*Y -> X +* lambda Y
if( hi instanceof BinaryOp
&& (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS)
- && ((BinaryOp)hi).getInput().get(0).getDataType()==DataType.MATRIX && ((BinaryOp)hi).getInput().get(1) instanceof BinaryOp
+ && hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp
&& (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
{
//Check that the inner binary Op is a product of Scalar times Matrix or viceversa
- Hop innerBinaryOp = ((BinaryOp)hi).getInput().get(1);
+ Hop innerBinaryOp = hi.getInput().get(1);
if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX)
|| (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
{
@@ -1934,8 +1934,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
- OpOp3 operator = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
- TernaryOp ternOp=new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, operator, ((BinaryOp)hi).getInput().get(0), lamda, matrix);
+ OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
+ TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
+ HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0));
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
HopRewriteUtils.addChildReference(parent, ternOp, pos);
@@ -1944,7 +1945,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return ternOp;
}
}
+
return hi;
-
}
}