You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/12 19:32:18 UTC

[4/5] incubator-systemml git commit: [SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats

[SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats

In the spirit of our SPOOF compiler framework and the existing
sum(X%*%Y) rewrite, this patch adds the following two sum-product
rewrites (the first of which applies multiple times in stratstats):

* colSums(X %*% Y) -> colSums(X) %*% Y
* rowSums(X %*% Y) -> X %*% rowSums(Y)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b3ba9916
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b3ba9916
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b3ba9916

Branch: refs/heads/master
Commit: b3ba991604cf79c5e3e2c0992fe2439ae47ce023
Parents: d3e617b
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Feb 12 08:29:36 2017 +0100
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Feb 12 09:42:03 2017 +0100

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  | 54 +++++++++++---------
 1 file changed, 30 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3ba9916/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e8e3862..6ffcbd5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2547,11 +2547,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 
 	private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos)
 	{
-		//sum(A%*%B) -> sum(t(colSums(A))*rowSums(B))
-		//if not dot product, not applied since aggregate removed
-		//if sum not the only consumer, not applied to prevent redundancy 
+		//sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
+		//colSums(A%*%B) -> colSums(A)%*%B
+		//rowSums(A%*%B) -> A%*%rowSums(B)
+		//-- if not dot product, not applied since aggregate removed
+		//-- if sum not the only consumer, not applied to prevent redundancy 
 		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM  //sum
-			&& ((AggUnaryOp)hi).getDirection() == Direction.RowCol	         //full aggregate
 			&& hi.getInput().get(0) instanceof AggBinaryOp                   //A%*%B
 			&& (hi.getInput().get(0).getDim1()>1 || hi.getInput().get(0).getDim2()>1) //not dot product
 			&& hi.getInput().get(0).getParent().size()==1 )     //not multiple consumers of matrix mult
@@ -2560,34 +2561,39 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			Hop left = hi2.getInput().get(0);
 			Hop right = hi2.getInput().get(1);
 				
-			//remove link from parent to diag
+			//remove link from parent to matrix mult
 			HopRewriteUtils.removeChildReference(hi, hi2);
 				
 			//create new operators
-			AggUnaryOp colSum = new AggUnaryOp(left.getName(), left.getDataType(), left.getValueType(), AggOp.SUM, Direction.Col, left);
-			colSum.setRowsInBlock(left.getRowsInBlock());
-			colSum.setColsInBlock(left.getColsInBlock());
-			colSum.refreshSizeInformation();
-			ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
-			AggUnaryOp rowSum = new AggUnaryOp(right.getName(), right.getDataType(), right.getValueType(), AggOp.SUM, Direction.Row, right);
-			rowSum.setRowsInBlock(right.getRowsInBlock());
-			rowSum.setColsInBlock(right.getColsInBlock());
-			rowSum.refreshSizeInformation();
-			BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, trans, rowSum);
-			mult.setRowsInBlock(right.getRowsInBlock());
-			mult.setColsInBlock(right.getColsInBlock());
-			mult.refreshSizeInformation();
-				
+			Hop root = null;
+			//pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
+			if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol ) {
+				AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+				ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
+				AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+				root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT);
+				LOG.debug("Applied simplifySumMatrixMult RC.");
+			}
+			//colSums(A%*%B) -> colSums(A)%*%B
+			else if( ((AggUnaryOp)hi).getDirection() == Direction.Col ) {
+				AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+				root = HopRewriteUtils.createMatrixMultiply(colSum, right);
+				LOG.debug("Applied simplifySumMatrixMult C.");
+			}
+			//rowSums(A%*%B) -> A%*%rowSums(B)
+			else if( ((AggUnaryOp)hi).getDirection() == Direction.Row ) {
+				AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+				root = HopRewriteUtils.createMatrixMultiply(left, rowSum);
+				LOG.debug("Applied simplifySumMatrixMult R.");
+			}
 			
 			//rehang new subdag under current node (keep hi intact)
-			HopRewriteUtils.addChildReference(hi, mult, 0);				
+			HopRewriteUtils.addChildReference(hi, root, 0);				
 			hi.refreshSizeInformation();
-				
+			
 			//cleanup if only consumer of intermediate
 			if( hi2.getParent().isEmpty() ) 
-				HopRewriteUtils.removeAllChildReferences( hi2 );
-			
-			LOG.debug("Applied simplifySumMatrixMult.");	
+				HopRewriteUtils.removeAllChildReferences( hi2 );	
 		}
 		
 		return hi;