You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/29 18:56:15 UTC
[systemds] branch master updated: [SYSTEMDS-3076] Additional hop
rewrites for colMeans sequences
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5a8c979 [SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
5a8c979 is described below
commit 5a8c979bdadc8285a63979e5894d80cf6d94dcd8
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jul 29 20:34:32 2021 +0200
[SYSTEMDS-3076] Additional hop rewrites for colMeans sequences
This patch adds two new rewrites that help remove unnecessary operations
in PCA with shifting (see functions/compress/WorkloadAlgorithmTest):
1) colSums(X) / N -> colMeans(X), if N = nrow(X) (a precondition for rewrite 2, and fewer operations)
2) colMeans(X-colMeans(X)) -> matrix(0,1,ncol(X)), also with the centered term divided or multiplied by a row vector (e.g., colSds(X))
After these rewrites have been applied, various additional rewrites
are triggered that remove unnecessary terms involving empty matrix multiplications.
---
src/main/java/org/apache/sysds/common/Types.java | 7 ++-
.../RewriteAlgebraicSimplificationDynamic.java | 60 ++++++++++++++++++++--
.../RewriteAlgebraicSimplificationStatic.java | 2 +-
3 files changed, 62 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index badf307..da53091 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -144,7 +144,12 @@ public class Types
RowCol, // full aggregate
Row, // row aggregate (e.g., rowSums)
Col; // column aggregate (e.g., colSums)
-
+ public boolean isRow() { //true only for row aggregates (e.g., rowSums)
+ return this == Row;
+ }
+ public boolean isCol() { //true only for column aggregates (e.g., colSums)
+ return this == Col;
+ }
@Override
public String toString() {
switch(this) {
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 269050c..5a29e6b 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -160,13 +160,15 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = removeUnnecessaryAppendTSMM(hop, hi, i); //e.g., X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A + t(B)%*%B + t(C)%*%C
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = fuseDatagenAndReorgOperation(hop, hi, i); //e.g., t(rand(rows=10,cols=1)) -> rand(rows=1,cols=10), if one dim=1
- hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colsums(X) -> sum(X) or X, if col/row vector
- hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowsums(X) -> sum(X) or X, if row/col vector
+ hi = simplifyColwiseAggregate(hop, hi, i); //e.g., colSums(X) -> sum(X) or X, if col/row vector
+ hi = simplifyRowwiseAggregate(hop, hi, i); //e.g., rowSums(X) -> sum(X) or X, if row/col vector
+ hi = simplifyMeanAggregation(hop, hi, i); //e.g., colSums(X)/N -> colMeans(X) if N = nrow(X)
hi = simplifyColSumsMVMult(hop, hi, i); //e.g., colSums(X*Y) -> t(Y) %*% X, if Y col vector
hi = simplifyRowSumsMVMult(hop, hi, i); //e.g., rowSums(X*Y) -> X %*% t(Y), if Y row vector
hi = simplifyUnnecessaryAggregate(hop, hi, i); //e.g., sum(X) -> as.scalar(X), if 1x1 dims
hi = simplifyEmptyAggregate(hop, hi, i); //e.g., sum(X) -> 0, if nnz(X)==0
- hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
+ hi = simplifyEmptyColMeans(hop, hi, i); //e.g., colMeans(X-colMeans(X)) if none or scaling by scalars/col-vectors
+ hi = simplifyEmptyUnaryOperation(hop, hi, i); //e.g., round(X) -> matrix(0,nrow(X),ncol(X)), if nnz(X)==0
hi = simplifyEmptyReorgOperation(hop, hi, i); //e.g., t(X) -> matrix(0, ncol(X), nrow(X))
hi = simplifyEmptySortOperation(hop, hi, i); //e.g., order(X) -> seq(1, nrow(X)), if nnz(X)==0
hi = simplifyEmptyMatrixMult(hop, hi, i); //e.g., X%*%Y -> matrix(0,...), if nnz(Y)==0 | X if Y==matrix(1,1,1)
@@ -722,6 +724,30 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+ private static Hop simplifyMeanAggregation( Hop parent, Hop hi, int pos ) {
+ // colSums(X)/nrow(X) -> colMeans(X), rowSums(X)/ncol(X) -> rowMeans(X); full sum(X)/N is not rewritten
+ if( HopRewriteUtils.isBinary(hi, OpOp2.DIV)
+ && HopRewriteUtils.isAggUnaryOp(hi.getInput(0), AggOp.SUM)
+ && hi.getInput(0).getParent().size()==1 //agg is turned into a mean in place below, so no other consumer may see it
+ && hi.getInput(1).getDataType().isScalar()) //divisor N must be a scalar
+ {
+ AggUnaryOp agg = (AggUnaryOp)hi.getInput(0);
+ Hop in = agg.getInput(0);
+ Hop N = hi.getInput(1);
+ if( (agg.getDirection().isRow() && HopRewriteUtils.isSizeExpressionOf(N, in, false)) //rowSums: N = ncol(X)
+ || (agg.getDirection().isCol() && HopRewriteUtils.isSizeExpressionOf(N, in, true)) ) //colSums: N = nrow(X)
+ {
+ HopRewriteUtils.replaceChildReference(parent, hi, agg, pos); //parent now consumes agg directly, bypassing the division
+ HopRewriteUtils.cleanupUnreferenced(hi, N);
+ agg.setOp(AggOp.MEAN); //in-place SUM -> MEAN, same output dims
+ hi = agg;
+ LOG.debug("Applied simplifyMeanAggregation");
+ }
+ }
+
+ return hi;
+ }
+
private static Hop simplifyColSumsMVMult( Hop parent, Hop hi, int pos )
{
//colSums(X*Y) -> t(Y) %*% X, if Y col vector; additional transpose later
@@ -821,7 +847,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private static Hop simplifyEmptyAggregate(Hop parent, Hop hi, int pos)
{
- if( hi instanceof AggUnaryOp )
+ if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
Hop input = uhi.getInput().get(0);
@@ -848,6 +874,30 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+ private static Hop simplifyEmptyColMeans(Hop parent, Hop hi, int pos)
+ { //colMeans of column-centered X (optionally scaled per column) -> all-zero 1 x ncol(X) matrix
+ if( hi.dimsKnown() && HopRewriteUtils.isAggUnaryOp(hi, AggOp.MEAN, Direction.Col) ) {
+ Hop in = hi.getInput(0);
+ //case 1: colMeans(X-colMeans(X)) without scaling
+ boolean apply = HopRewriteUtils.isBinary(in, OpOp2.MINUS)
+ && HopRewriteUtils.isAggUnaryOp(in.getInput(1), AggOp.MEAN, Direction.Col)
+ && in.getInput(0) == in.getInput(1).getInput(0); //same X hop on both sides (reference equality, requires CSE)
+ //case 2: colMeans((X-colMeans(X))/s) or colMeans((X-colMeans(X))*s), s a 1-row factor such as colSds(X); NOTE(review): scalar s presumably fails dim1==1 - confirm
+ apply = apply || (HopRewriteUtils.isBinary(in, OpOp2.DIV, OpOp2.MULT)
+ && in.getInput(1).getDim1()==1 //row vector: one factor per column keeps each column mean 0 (NOTE(review): a 0 divisor gives NaN, not 0)
+ && HopRewriteUtils.isBinary(in.getInput(0), OpOp2.MINUS)
+ && HopRewriteUtils.isAggUnaryOp(in.getInput(0).getInput(1), AggOp.MEAN, Direction.Col)
+ && in.getInput(0).getInput(0) == in.getInput(0).getInput(1).getInput(0));
+ if( apply ) {
+ Hop hnew = HopRewriteUtils.createDataGenOp(hi, hi, 0); //all-zero matrix shaped like hi (1 x ncol(X))
+ HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+ hi = hnew;
+ LOG.debug("Applied simplifyEmptyColMeans");
+ }
+ }
+ return hi;
+ }
+
private static Hop simplifyEmptyUnaryOperation(Hop parent, Hop hi, int pos)
{
if( hi instanceof UnaryOp )
@@ -866,7 +916,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
LOG.debug("Applied simplifyEmptyUnaryOperation");
}
- }
+ }
}
return hi;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f0d9dea..56854ff 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -158,7 +158,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = simplifyBinaryMatrixScalarOperation(hop, hi, i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X))
hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
- hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lamda*X) -> lamda*sum(X)
+ hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lambda*X) -> lambda*sum(X)
hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)