You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:47 UTC
[21/23] systemml git commit: Add new `wumm` pattern to pick up
element-wise multiply rewrite.
Add new `wumm` pattern to pick up element-wise multiply rewrite.
The new pattern recognizes a multiplication by the literal 2 (`*2` or `2*`) applied to `W*(U%*%t(V))`.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4
Branch: refs/heads/master
Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed
Parents: e93c487
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:14:48 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:14:48 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/ProgramRewriter.java | 2 +-
.../RewriteAlgebraicSimplificationDynamic.java | 44 +++++++++++++++++++-
2 files changed, 43 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 59565df..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
//internal local debug level
- private static final boolean LDEBUG = false;
+ private static final boolean LDEBUG = false;
private static final boolean CHECK = false;
private ArrayList<HopRewriteRule> _dagRuleSet = null;
http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..6246270 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.OpOp4;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
}
+
+ //Pattern 1.5) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
+ if( !appliedPattern
+ && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+ && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+ || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+ {
+ final Hop nl; // non-literal
+ if( hi.getInput().get(0) instanceof LiteralOp ) {
+ nl = hi.getInput().get(1);
+ } else {
+ nl = hi.getInput().get(0);
+ }
+
+ if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+ && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
+ && nl.getDim2() > 1 //not applied for vector-vector mult
+ && nl.getInput().get(0).getDataType() == DataType.MATRIX
+ && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
+ && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+ && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+ && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+ {
+ final Hop W = nl.getInput().get(0);
+ final Hop U = nl.getInput().get(1).getInput().get(0);
+ Hop V = nl.getInput().get(1).getInput().get(1);
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
+ hnew.refreshSizeInformation();
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
+ }
+ }
//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
if( !appliedPattern