You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/12 19:32:18 UTC
[4/5] incubator-systemml git commit: [SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats
[SYSTEMML-1254] New sum-product rewrites (agg pushdown), stratstats
In the spirit of our SPOOF compiler framework and the existing
sum(X%*%Y) rewrite, this patch adds the following two sum-product
rewrites (the first of which applies multiple times in stratstats):
* colSums(X %*% Y) -> colSums(X) %*% Y
* rowSums(X %*% Y) -> X %*% rowSums(Y)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b3ba9916
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b3ba9916
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b3ba9916
Branch: refs/heads/master
Commit: b3ba991604cf79c5e3e2c0992fe2439ae47ce023
Parents: d3e617b
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Feb 12 08:29:36 2017 +0100
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Feb 12 09:42:03 2017 +0100
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationDynamic.java | 54 +++++++++++---------
1 file changed, 30 insertions(+), 24 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b3ba9916/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index e8e3862..6ffcbd5 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2547,11 +2547,12 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private Hop simplifySumMatrixMult(Hop parent, Hop hi, int pos)
{
- //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B))
- //if not dot product, not applied since aggregate removed
- //if sum not the only consumer, not applied to prevent redundancy
+ //sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
+ //colSums(A%*%B) -> colSums(A)%*%B
+ //rowSums(A%*%B) -> A%*%rowSums(B)
+ //-- if not dot product, not applied since aggregate removed
+ //-- if sum not the only consumer, not applied to prevent redundancy
if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
- && ((AggUnaryOp)hi).getDirection() == Direction.RowCol //full aggregate
&& hi.getInput().get(0) instanceof AggBinaryOp //A%*%B
&& (hi.getInput().get(0).getDim1()>1 || hi.getInput().get(0).getDim2()>1) //not dot product
&& hi.getInput().get(0).getParent().size()==1 ) //not multiple consumers of matrix mult
@@ -2560,34 +2561,39 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
Hop left = hi2.getInput().get(0);
Hop right = hi2.getInput().get(1);
- //remove link from parent to diag
+ //remove link from parent to matrix mult
HopRewriteUtils.removeChildReference(hi, hi2);
//create new operators
- AggUnaryOp colSum = new AggUnaryOp(left.getName(), left.getDataType(), left.getValueType(), AggOp.SUM, Direction.Col, left);
- colSum.setRowsInBlock(left.getRowsInBlock());
- colSum.setColsInBlock(left.getColsInBlock());
- colSum.refreshSizeInformation();
- ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
- AggUnaryOp rowSum = new AggUnaryOp(right.getName(), right.getDataType(), right.getValueType(), AggOp.SUM, Direction.Row, right);
- rowSum.setRowsInBlock(right.getRowsInBlock());
- rowSum.setColsInBlock(right.getColsInBlock());
- rowSum.refreshSizeInformation();
- BinaryOp mult = new BinaryOp(right.getName(), right.getDataType(), right.getValueType(), OpOp2.MULT, trans, rowSum);
- mult.setRowsInBlock(right.getRowsInBlock());
- mult.setColsInBlock(right.getColsInBlock());
- mult.refreshSizeInformation();
-
+ Hop root = null;
+ //pattern: sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), later rewritten to dot-product
+ if( ((AggUnaryOp)hi).getDirection() == Direction.RowCol ) {
+ AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+ ReorgOp trans = HopRewriteUtils.createTranspose(colSum);
+ AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+ root = HopRewriteUtils.createBinary(trans, rowSum, OpOp2.MULT);
+ LOG.debug("Applied simplifySumMatrixMult RC.");
+ }
+ //colSums(A%*%B) -> colSums(A)%*%B
+ else if( ((AggUnaryOp)hi).getDirection() == Direction.Col ) {
+ AggUnaryOp colSum = HopRewriteUtils.createAggUnaryOp(left, AggOp.SUM, Direction.Col);
+ root = HopRewriteUtils.createMatrixMultiply(colSum, right);
+ LOG.debug("Applied simplifySumMatrixMult C.");
+ }
+ //rowSums(A%*%B) -> A%*%rowSums(B)
+ else if( ((AggUnaryOp)hi).getDirection() == Direction.Row ) {
+ AggUnaryOp rowSum = HopRewriteUtils.createAggUnaryOp(right, AggOp.SUM, Direction.Row);
+ root = HopRewriteUtils.createMatrixMultiply(left, rowSum);
+ LOG.debug("Applied simplifySumMatrixMult R.");
+ }
//rehang new subdag under current node (keep hi intact)
- HopRewriteUtils.addChildReference(hi, mult, 0);
+ HopRewriteUtils.addChildReference(hi, root, 0);
hi.refreshSizeInformation();
-
+
//cleanup if only consumer of intermediate
if( hi2.getParent().isEmpty() )
- HopRewriteUtils.removeAllChildReferences( hi2 );
-
- LOG.debug("Applied simplifySumMatrixMult.");
+ HopRewriteUtils.removeAllChildReferences( hi2 );
}
return hi;