You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/29 18:56:15 UTC

[systemds] branch master updated: [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a8c979  [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
5a8c979 is described below

commit 5a8c979bdadc8285a63979e5894d80cf6d94dcd8
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jul 29 20:34:32 2021 +0200

    [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
    
    This patch adds two new rewrites that help remove unnecessary operations
    in PCA with shifting (see functions/compress/WorkloadAlgorithmTest):
    
    1) colSums(X) / N -> colMeans(X) (precondition and fewer ops)
    2) colMeans((X-colMeans(X))/...) -> matrix(0,1,ncol(X))
    
    After these rewrites have been applied, various additional rewrites
    trigger to remove unnecessary terms with empty matrix multiplications.
---
 src/main/java/org/apache/sysds/common/Types.java   |  7 ++-
 .../RewriteAlgebraicSimplificationDynamic.java     | 60 ++++++++++++++++++++--
 .../RewriteAlgebraicSimplificationStatic.java      |  2 +-
 3 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index badf307..da53091 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -144,7 +144,12 @@ public class Types
 		RowCol, // full aggregate
 		Row,    // row aggregate (e.g., rowSums)
 		Col;    // column aggregate (e.g., colSums)
-		
+		public boolean isRow() { //true only for the row aggregate direction (e.g., rowSums)
+			return this == Row; //false for RowCol and Col
+		}
+		public boolean isCol() { //true only for the column aggregate direction (e.g., colSums)
+			return this == Col; //false for RowCol and Row
+		}
 		@Override
 		public String toString() {
 			switch(this) {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 269050c..5a29e6b 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -160,13 +160,15 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 				hi = removeUnnecessaryAppendTSMM(hop, hi, i);     //e.g., X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A + t(B)%*%B + t(C)%*%C
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
 				hi = fuseDatagenAndReorgOperation(hop, hi, i);    //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
-			hi = simplifyColwiseAggregate(hop, hi, i);        //e.g., colsums(X) -> sum(X) or X, if col/row vector
-			hi = simplifyRowwiseAggregate(hop, hi, i);        //e.g., rowsums(X) -> sum(X) or X, if row/col vector
+			hi = simplifyColwiseAggregate(hop, hi, i);        //e.g., colSums(X) -> sum(X) or X, if col/row vector
+			hi = simplifyRowwiseAggregate(hop, hi, i);        //e.g., rowSums(X) -> sum(X) or X, if row/col vector
+			hi = simplifyMeanAggregation(hop, hi, i);         //e.g., colSums(X)/N -> colMeans(X) if N = nrow(X)
 			hi = simplifyColSumsMVMult(hop, hi, i);           //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
 			hi = simplifyRowSumsMVMult(hop, hi, i);           //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
 			hi = simplifyUnnecessaryAggregate(hop, hi, i);    //e.g., sum(X) -> as.scalar(X), if 1x1 dims
 			hi = simplifyEmptyAggregate(hop, hi, i);          //e.g., sum(X) -> 0, if nnz(X)==0
-			hi = simplifyEmptyUnaryOperation(hop, hi, i);     //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0			
+			hi = simplifyEmptyColMeans(hop, hi, i);           //e.g., colMeans(X-colMeans(X)) if none or scaling by scalars/col-vectors
+			hi = simplifyEmptyUnaryOperation(hop, hi, i);     //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
 			hi = simplifyEmptyReorgOperation(hop, hi, i);     //e.g., t(X) -> matrix(0, ncol(X), nrow(X)) 
 			hi = simplifyEmptySortOperation(hop, hi, i);      //e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0 
 			hi = simplifyEmptyMatrixMult(hop, hi, i);         //e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
@@ -722,6 +724,30 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		return hi;
 	}
 	
+	private static Hop simplifyMeanAggregation( Hop parent, Hop hi, int pos ) {
+		// e.g., colSums(X)/N -> colMeans(X), if N = nrow(X); row-wise sums are matched analogously against the other dimension (presumably ncol(X))
+		if( HopRewriteUtils.isBinary(hi, OpOp2.DIV) //hi is a division ...
+			&& HopRewriteUtils.isAggUnaryOp(hi.getInput(0), AggOp.SUM) //... of a sum aggregate
+			&& hi.getInput(0).getParent().size()==1 //sum only consumed by this division; prevent repeated scans
+			&& hi.getInput(1).getDataType().isScalar()) //... by a scalar
+		{
+			AggUnaryOp agg = (AggUnaryOp)hi.getInput(0); //cast guarded by the isAggUnaryOp check above
+			Hop in = agg.getInput(0); //X, the aggregated input
+			Hop N = hi.getInput(1); //scalar divisor
+			if( (agg.getDirection().isRow() && HopRewriteUtils.isSizeExpressionOf(N, in, false)) //row aggregate: size flag false, presumably N = ncol(X)
+				|| (agg.getDirection().isCol() && HopRewriteUtils.isSizeExpressionOf(N, in, true)) ) //col aggregate: size flag true, N = nrow(X) per the rule above
+			{
+				HopRewriteUtils.replaceChildReference(parent, hi, agg, pos); //parent now references the aggregate instead of the division
+				HopRewriteUtils.cleanupUnreferenced(hi, N); //presumably detaches the division and N if no longer referenced
+				agg.setOp(AggOp.MEAN); //reuse the existing aggregate hop, switching sum to mean
+				hi = agg; //return the rewritten root
+				LOG.debug("Applied simplifyMeanAggregation");
+			}
+		}
+		
+		return hi; //unchanged hi if the pattern did not match
+	}
+	
 	private static Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos ) 
 	{
 		//colSums(X*Y) -> t(Y) %*% X, if Y col vector; additional transpose later
@@ -821,7 +847,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	
 	private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos) 
 	{
-		if( hi instanceof AggUnaryOp  ) 
+		if( hi instanceof AggUnaryOp )
 		{
 			AggUnaryOp uhi = (AggUnaryOp)hi;
 			Hop input = uhi.getInput().get(0);
@@ -848,6 +874,30 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		return hi;
 	}
 	
+	private static Hop simplifyEmptyColMeans(Hop parent, Hop hi, int pos) 
+	{ //e.g., colMeans((X-colMeans(X))/...) -> matrix(0,1,ncol(X)), replaces colMeans over column-centered data
+		if( hi.dimsKnown() && HopRewriteUtils.isAggUnaryOp(hi, AggOp.MEAN, Direction.Col) ) { //only colMeans with known output dims
+			Hop in = hi.getInput(0);
+			//pattern 1: colMeans(X-colMeans(X)) without scaling
+			boolean apply = HopRewriteUtils.isBinary(in, OpOp2.MINUS)
+				&& HopRewriteUtils.isAggUnaryOp(in.getInput(1), AggOp.MEAN, Direction.Col)
+				&& in.getInput(0) == in.getInput(1).getInput(0); //same X hop on both sides (reference equality); requires CSE
+			//pattern 2: colMeans((X-colMeans(X))/colSds(X)), centered data divided/multiplied by a single-row operand; NOTE(review): scalars only match if their dim1 is 1 - confirm
+			apply = apply || (HopRewriteUtils.isBinary(in, OpOp2.DIV, OpOp2.MULT)
+				&& in.getInput(1).getDim1()==1 //right operand has one row (row vector)
+				&& HopRewriteUtils.isBinary(in.getInput(0), OpOp2.MINUS)
+				&& HopRewriteUtils.isAggUnaryOp(in.getInput(0).getInput(1), AggOp.MEAN, Direction.Col)
+				&& in.getInput(0).getInput(0) == in.getInput(0).getInput(1).getInput(0)); //same X hop by reference, as in pattern 1
+			if( apply ) {
+				Hop hnew = HopRewriteUtils.createDataGenOp(hi, hi, 0); //empty (all-zero) matrix; presumably sized by rows/cols of hi
+				HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos); //parent now references the empty matrix instead of colMeans
+				hi = hnew;
+				LOG.debug("Applied simplifyEmptyColMeans");
+			}
+		}
+		return hi; //unchanged hi if no pattern matched
+	}
+	
 	private static Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos) 
 	{
 		if( hi instanceof UnaryOp  ) 
@@ -866,7 +916,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 					
 					LOG.debug("Applied simplifyEmptyUnaryOperation");
 				}
-			}			
+			}
 		}
 		
 		return hi;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f0d9dea..56854ff 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -158,7 +158,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
 			hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
 			hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
-			hi = pushdownSumBinaryMult(hop, hi, i);              //e.g., sum(lamda*X) -> lamda*sum(X)
+			hi = pushdownSumBinaryMult(hop, hi, i);              //e.g., sum(lambda*X) -> lambda*sum(X)
 			hi = simplifyUnaryPPredOperation(hop, hi, i);        //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
 			hi = simplifyTransposedAppend(hop, hi, i);           //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION)