You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:47 UTC

[21/23] systemml git commit: Add new `wumm` pattern to pick up element-wise multiply rewrite.

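Add a new `wumm` rewrite pattern that absorbs an outer multiplication by 2 into the weighted unary matrix-multiply operator.

The new pattern recognizes a multiplication by the literal 2 (either `*2` or `2*`) applied outside `W*(U%*%t(V))`, and rewrites the whole expression into a single `WUMM` quaternary operator.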


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4

Branch: refs/heads/master
Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed
Parents: e93c487
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:14:48 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:14:48 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |  2 +-
 .../RewriteAlgebraicSimplificationDynamic.java  | 44 +++++++++++++++++++-
 2 files changed, 43 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 59565df..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
 	private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
 	
 	//internal local debug level
-	private static final boolean LDEBUG = false; 
+	private static final boolean LDEBUG = false;
 	private static final boolean CHECK = false;
 	
 	private ArrayList<HopRewriteRule> _dagRuleSet = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..6246270 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.ReorgOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			appliedPattern = true;
 			LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");	
 		}
+
+		//Pattern 1.5) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
+		if( !appliedPattern
+				&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+				&& (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+					|| HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+		{
+			final Hop nl; // non-literal
+			if( hi.getInput().get(0) instanceof LiteralOp ) {
+				nl = hi.getInput().get(1);
+			} else {
+				nl = hi.getInput().get(0);
+			}
+
+			if (       HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+					&& HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
+					&& nl.getDim2() > 1 //not applied for vector-vector mult
+					&& nl.getInput().get(0).getDataType() == DataType.MATRIX
+					&& nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
+					&& HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+					&& (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+					&& HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+			{
+				final Hop W = nl.getInput().get(0);
+				final Hop U = nl.getInput().get(1).getInput().get(0);
+				Hop V = nl.getInput().get(1).getInput().get(1);
+				if( !HopRewriteUtils.isTransposeOperation(V) )
+					V = HopRewriteUtils.createTranspose(V);
+				else
+					V = V.getInput().get(0);
+
+				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+						OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
+				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
+				hnew.refreshSizeInformation();
+
+				appliedPattern = true;
+				LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
+			}
+		}
 		
 		//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
 		if( !appliedPattern