You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/03/21 19:30:34 UTC
incubator-systemml git commit: New simplification rewrite 'simplify
transpose-aggbinary-binary chains'
Repository: incubator-systemml
Updated Branches:
refs/heads/master ab5b27c97 -> cc927b71d
New simplification rewrite 'simplify transpose-aggbinary-binary chains'
Our existing transpose-multiply rewrite (e.g., t(X)%*%y -> t(t(y)%*%X))
only covers transpose-matrixmult and is applied via hop-lop rewrites as
late as possible to account for size information. This patch introduces
a new static algebraic simplification rewrite: t(t(A)%*%t(B)+C) ->
B%*%A+t(C), where we support arbitrary "basic" binary operations. Note
that this rewrite can be applied without dimension information because
t(C) is at most as large as the removed output transpose, and we
additionally remove both input transposes t(A) and t(B).
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cc927b71
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cc927b71
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cc927b71
Branch: refs/heads/master
Commit: cc927b71d9bb879bb84681f5efa0f887ccc051da
Parents: ab5b27c
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sun Mar 20 23:22:25 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Mar 20 23:22:25 2016 -0700
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationStatic.java | 45 ++++++++++++++++++++
1 file changed, 45 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc927b71/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 2e2ec79..37a0dc2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -151,6 +151,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = simplifyConstantSort(hop, hi, i); //e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i); //e.g., order(matrix())->seq;
hi = removeUnnecessaryReorgOperation(hop, hi, i); //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
+ hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplfiy binary or dyn rewrites
hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
@@ -1365,6 +1366,50 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
+
+ /**
+  * Pattern: t(t(A)%*%t(B) op C) -> B%*%A op t(C), where op is any "basic"
+  * binary operation (one that supports matrix-scalar operations) and C is a
+  * matrix. Example: t(t(A)%*%t(B)+C) -> B%*%A+t(C).
+  *
+  * Since op is applied cell-wise, t(X op C) == t(X) op t(C), and
+  * t(t(A)%*%t(B)) == B%*%A. The rewrite removes the output transpose and both
+  * input transposes at the cost of a single transpose of C, which is never
+  * larger than the removed output transpose; hence no size information is needed.
+  * The matched binary operation is preserved (it is NOT forced to PLUS).
+  *
+  * @param parent parent hop of hi; its child reference at pos is replaced on rewrite
+  * @param hi candidate hop (expected to be the outer transpose)
+  * @param pos position of hi in the input list of parent
+  * @return the rewritten hop if the pattern matched, otherwise hi unchanged
+  * @throws HopsException if creating or rewiring hops fails
+  */
+ private Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos)
+     throws HopsException
+ {
+     if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //transpose
+         && hi.getInput().get(0) instanceof BinaryOp                       //basic binary
+         && ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
+     {
+         BinaryOp binary = (BinaryOp)hi.getInput().get(0);
+         Hop left = binary.getInput().get(0);
+         Hop C = binary.getInput().get(1);
+
+         //check matrix mult and both inputs transposes w/ single consumer
+         //(single consumer: removing the transposes must not force recomputation elsewhere)
+         if( left instanceof AggBinaryOp && C.getDataType().isMatrix()
+             && left.getInput().get(0) instanceof ReorgOp
+             && ((ReorgOp)left.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE
+             && left.getInput().get(0).getParent().size()==1
+             && left.getInput().get(1) instanceof ReorgOp
+             && ((ReorgOp)left.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE
+             && left.getInput().get(1).getParent().size()==1 )
+         {
+             Hop A = left.getInput().get(0).getInput().get(0);
+             Hop B = left.getInput().get(1).getInput().get(0);
+
+             //t(t(A)%*%t(B)) == B%*%A; t(X op C) == t(X) op t(C) for cell-wise op
+             AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A);
+             ReorgOp rop = HopRewriteUtils.createTranspose(C);
+             //preserve the matched operation instead of hard-coding PLUS
+             BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, binary.getOp());
+
+             HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+             HopRewriteUtils.addChildReference(parent, bop, pos);
+
+             hi = bop;
+             LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
+         }
+     }
+
+     return hi;
+ }
/**
* Pattners: t(t(X)) -> X, rev(rev(X)) -> X