You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/03/21 19:30:34 UTC

incubator-systemml git commit: New simplification rewrite 'simplify transpose-aggbinary-binary chains'

Repository: incubator-systemml
Updated Branches:
  refs/heads/master ab5b27c97 -> cc927b71d


New simplification rewrite 'simplify transpose-aggbinary-binary chains'

Our existing transpose-multiply rewrite (e.g., t(X)%*%y -> t(t(y)%*%X))
only covers transpose-matrixmult and is applied via hop-lop rewrites as
late as possible to account for size information. This patch introduces
a new static algebraic simplification rewrite: t(t(A)%*%t(B)+C) ->
B%*%A+t(C), where we support arbitrary "basic" binary operations. Note
that this rewrite can be applied without dimension information because
t(C) is at most as large as the removed output transpose, while both
input transposes t(A) and t(B) are removed as well.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/cc927b71
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/cc927b71
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/cc927b71

Branch: refs/heads/master
Commit: cc927b71d9bb879bb84681f5efa0f887ccc051da
Parents: ab5b27c
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sun Mar 20 23:22:25 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Mar 20 23:22:25 2016 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   | 45 ++++++++++++++++++++
 1 file changed, 45 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc927b71/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 2e2ec79..37a0dc2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -151,6 +151,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = simplifyConstantSort(hop, hi, i);               //e.g., order(matrix())->matrix/seq; 
 			hi = simplifyOrderedSort(hop, hi, i);                //e.g., order(matrix())->seq; 
 			hi = removeUnnecessaryReorgOperation(hop, hi, i);    //e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
+			hi = simplifyTransposeAggBinBinaryChains(hop, hi, i);//e.g., t(t(A)%*%t(B)+C) -> B%*%A+t(C)
 			hi = removeUnnecessaryMinus(hop, hi, i);             //e.g., -(-X)->X; potentially introduced by simplfiy binary or dyn rewrites
 			hi = simplifyGroupedAggregate(hi);          	     //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
 			hi = fuseMinusNzBinaryOperation(hop, hi, i);         //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
@@ -1365,6 +1366,50 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		
 		return hi;
 	}
+
+	/**
+	 * Pattern: t(t(A)%*%t(B) op C) -> B%*%A op t(C) for basic binary ops (supportsMatrixScalarOperations),
+	 * since t(t(A)%*%t(B)) = B%*%A and transpose distributes over such elementwise operations.
+	 * 
+	 * @param parent parent hop; its input at position pos is replaced if the rewrite applies
+	 * @param hi current hop, candidate outer transpose
+	 * @param pos input position of hi within parent
+	 * @return the new binary hop if the rewrite applied, otherwise hi
+	 * @throws HopsException
+	 */
+	private Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		if( hi instanceof ReorgOp && ((ReorgOp)hi).getOp()==ReOrgOp.TRANSPOSE //transpose
+			&& hi.getInput().get(0) instanceof BinaryOp                       //basic binary
+			&& ((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations()
+			&& hi.getInput().get(0).getParent().size()==1 )                  //single consumer
+		{
+			BinaryOp binary = (BinaryOp)hi.getInput().get(0);
+			Hop left = binary.getInput().get(0);
+			Hop C = binary.getInput().get(1);
+			//check single-consumer matrix mult over single-consumer transposes (else matmult is duplicated)
+			if( left instanceof AggBinaryOp && left.getParent().size()==1 && C.getDataType().isMatrix()
+				&& left.getInput().get(0).getParent().size()==1 && left.getInput().get(0) instanceof ReorgOp
+				&& ((ReorgOp)left.getInput().get(0)).getOp()==ReOrgOp.TRANSPOSE
+				&& left.getInput().get(1).getParent().size()==1 && left.getInput().get(1) instanceof ReorgOp
+				&& ((ReorgOp)left.getInput().get(1)).getOp()==ReOrgOp.TRANSPOSE )
+			{
+				Hop A = left.getInput().get(0).getInput().get(0);
+				Hop B = left.getInput().get(1).getInput().get(0);
+				//reuse the original binary op (e.g., -, *, /) instead of hard-coding PLUS
+				AggBinaryOp abop = HopRewriteUtils.createMatrixMultiply(B, A);
+				ReorgOp rop = HopRewriteUtils.createTranspose(C);
+				BinaryOp bop = HopRewriteUtils.createBinary(abop, rop, binary.getOp());
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+				HopRewriteUtils.addChildReference(parent, bop, pos);
+				hi = bop;
+				LOG.debug("Applied simplifyTransposeAggBinBinaryChains (line "+hi.getBeginLine()+").");
+			}
+		}
+		
+		return hi;
+	}
 	
 	/**
 	 * Pattners: t(t(X)) -> X, rev(rev(X)) -> X