You are viewing a plain text version of this content. (The hyperlink to the canonical version was not preserved in this text copy.)
Posted to commits@systemml.apache.org by mb...@apache.org on 2015/11/23 04:53:03 UTC
[8/8] incubator-systemml git commit: Merged simplification rewrites (quaternary operators)
Merged simplification rewrites (quaternary operators)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/a2f78e74
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/a2f78e74
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/a2f78e74
Branch: refs/heads/master
Commit: a2f78e74e51047144268205b0cbbe8f40cddf275
Parents: d70c452
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sun Nov 22 19:50:30 2015 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Nov 22 19:50:30 2015 -0800
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationDynamic.java | 6 +-
.../RewriteAlgebraicSimplificationStatic.java | 605 -------------------
2 files changed, 3 insertions(+), 608 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a2f78e74/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 1a0710f..5b1111a 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1391,7 +1391,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
&& ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW
&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
+ && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
{
Hop W = bop.getInput().get(0);
Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
@@ -1447,7 +1447,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
if( !appliedPattern
&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+ && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
&& bop.getInput().get(0) instanceof BinaryOp
&& bop.getInput().get(0).getDataType()==DataType.MATRIX
&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
@@ -1502,7 +1502,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
if( !appliedPattern
&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+ && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
&& bop.getInput().get(0) instanceof BinaryOp
&& bop.getInput().get(0).getDataType()==DataType.MATRIX
&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/a2f78e74/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index d2887a2..4a7b2d4 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1290,616 +1290,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
LOG.debug("Applied simplifyGroupedAggregateCount");
}
}
-<<<<<<< Upstream, based on branch 'master' of https://git-wip-us.apache.org/repos/asf/incubator-systemml.git
}
return hi;
}
-
- /**
- * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
- * Currently, this search includes the following three patterns:
- * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
- * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
- * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
- *
- * NOTE: We include transpose into the pattern because during runtime we need to compute
- * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
- * without additional memory requirements for internal transpose.
- *
- * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
- * U/V factors of rank up to the blocksize (1000). We enforce this contraint here during the general
- * rewrite because this is an uncommon case. Also, the intention is to remove this constaint as soon
- * as we generalized the runtime or hop/lop compilation.
- *
- * @param parent
- * @param hi
- * @param pos
- * @return
- * @throws HopsException
- */
- private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos)
- throws HopsException
- {
- //NOTE: there might be also a general simplification without custom operator
- //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
- Hop hnew = null;
-
- if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
- && ((AggUnaryOp)hi).getOp() == AggOp.SUM //all patterns rooted by sum()
- && hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by binary op
- && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
- {
- BinaryOp bop = (BinaryOp) hi.getInput().get(0);
- boolean appliedPattern = false;
-
- //Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
- //alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
- if( bop.getOp()==OpOp2.MULT && bop.getInput().get(1) instanceof BinaryOp
- && bop.getInput().get(0).getDataType()==DataType.MATRIX
- && HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
- && ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW
- && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
- {
- Hop W = bop.getInput().get(0);
- Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
-
- if( tmp instanceof BinaryOp && ((BinaryOp)tmp).getOp()==OpOp2.MINUS
- && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
- && tmp.getInput().get(0).getDataType() == DataType.MATRIX )
- {
- //a) sum (W * (X - U %*% t(V)) ^ 2)
- int uvIndex = -1;
- if( tmp.getInput().get(1) instanceof AggBinaryOp //ba gurantees matrices
- && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
- {
- uvIndex = 1;
- }
- //b) sum (W * (U %*% t(V) - X) ^ 2)
- else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices
- && HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
- {
- uvIndex = 0;
- }
-
- if( uvIndex >= 0 ) //rewrite match
- {
- Hop X = tmp.getInput().get((uvIndex==0)?1:0);
- Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
- Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(V) ) {
- V = HopRewriteUtils.createTranspose(V);
- }
- else{
- V = V.getInput().get(0);
- }
-
- //handle special case of post_nz
- if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
- W = new LiteralOp(1);
- }
-
- //construct quaternary hop
- hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp4.WSLOSS, X, U, V, W, true);
- HopRewriteUtils.setOutputParametersForScalar(hnew);
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
- }
- }
- }
-
- //Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
- //alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
- if( !appliedPattern
- && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
- && bop.getInput().get(0) instanceof BinaryOp
- && bop.getInput().get(0).getDataType()==DataType.MATRIX
- && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
- && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
- && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
- {
- Hop lleft = bop.getInput().get(0).getInput().get(0);
- Hop lright = bop.getInput().get(0).getInput().get(1);
-
- //a) sum ((X - W * (U %*% t(V))) ^ 2)
- int wuvIndex = -1;
- if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
- wuvIndex = 1;
- }
- //b) sum ((W * (U %*% t(V)) - X) ^ 2)
- else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
- wuvIndex = 0;
- }
-
- if( wuvIndex >= 0 ) //rewrite match
- {
- Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
- Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
-
- if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
- && tmp.getInput().get(0).getDataType() == DataType.MATRIX
- && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
- && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
- {
- Hop W = tmp.getInput().get(0);
- Hop U = tmp.getInput().get(1).getInput().get(0);
- Hop V = tmp.getInput().get(1).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(V) ) {
- V = HopRewriteUtils.createTranspose(V);
- }
- else {
- V = V.getInput().get(0);
- }
-
- hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp4.WSLOSS, X, U, V, W, false);
- HopRewriteUtils.setOutputParametersForScalar(hnew);
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
- }
- }
- }
-
- //Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
- //alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
- if( !appliedPattern
- && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
- && HopRewriteUtils.getDoubleValue((LiteralOp)bop.getInput().get(1))==2
- && bop.getInput().get(0) instanceof BinaryOp
- && bop.getInput().get(0).getDataType()==DataType.MATRIX
- && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
- && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
- && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
- {
- Hop lleft = bop.getInput().get(0).getInput().get(0);
- Hop lright = bop.getInput().get(0).getInput().get(1);
-
- //a) sum ((X - (U %*% t(V))) ^ 2)
- int uvIndex = -1;
- if( lright instanceof AggBinaryOp //ba gurantees matrices
- && HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- uvIndex = 1;
- }
- //b) sum (((U %*% t(V)) - X) ^ 2)
- else if( lleft instanceof AggBinaryOp //ba gurantees matrices
- && HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- uvIndex = 0;
- }
-
- if( uvIndex >= 0 ) //rewrite match
- {
- Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
- Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
- Hop W = new LiteralOp(1); //no weighting
- Hop U = tmp.getInput().get(0);
- Hop V = tmp.getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(V) ) {
- V = HopRewriteUtils.createTranspose(V);
- }
- else {
- V = V.getInput().get(0);
- }
-
- hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
- OpOp4.WSLOSS, X, U, V, W, false);
- HopRewriteUtils.setOutputParametersForScalar(hnew);
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
- }
- }
- }
-
- //relink new hop into original position
- if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- hi = hnew;
- }
-
- return hi;
- }
-
- /**
- *
- * @param parent
- * @param hi
- * @param pos
- * @return
- * @throws HopsException
- */
- private Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
- throws HopsException
- {
- Hop hnew = null;
- if( hi instanceof BinaryOp //all patterns subrooted by W *
- && ((BinaryOp) hi).getOp()==OpOp2.MULT
- && hi.getDim2() > 1 //not applied for vector-vector mult
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
- {
- UnaryOp uop = (UnaryOp) hi.getInput().get(1);
- boolean appliedPattern = false;
-
- //Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
- if( uop.getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) )
- {
- Hop W = hi.getInput().get(0);
- Hop Y = uop.getInput().get(0).getInput().get(0);
- Hop tX = uop.getInput().get(0).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(tX) ) {
- tX = HopRewriteUtils.createTranspose(tX);
- }
- else
- tX = tX.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WSIGMOID, W, Y, tX, false, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");
- }
-
- //Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
- if( !appliedPattern
- && uop.getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0) instanceof BinaryOp
- && ((BinaryOp)uop.getInput().get(0)).getOp()==OpOp2.MINUS
- && uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
- && HopRewriteUtils.getDoubleValueSafe(
- (LiteralOp)uop.getInput().get(0).getInput().get(0))==0
- && uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true))
- {
- Hop W = hi.getInput().get(0);
- Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
- Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(tX) ) {
- tX = HopRewriteUtils.createTranspose(tX);
- }
- else
- tX = tX.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WSIGMOID, W, Y, tX, false, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");
- }
-
- //Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
- if( !appliedPattern
- && uop.getOp() == OpOp1.LOG
- && uop.getInput().get(0) instanceof UnaryOp
- && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) )
- {
- Hop W = hi.getInput().get(0);
- Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
- Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(tX) ) {
- tX = HopRewriteUtils.createTranspose(tX);
- }
- else
- tX = tX.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WSIGMOID, W, Y, tX, true, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");
- }
-
- //Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
- if( !appliedPattern
- && uop.getOp() == OpOp1.LOG
- && uop.getInput().get(0) instanceof UnaryOp
- && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0).getInput().get(0) instanceof BinaryOp )
- {
- BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
-
- if( bop.getOp() == OpOp2.MINUS
- && bop.getInput().get(0) instanceof LiteralOp
- && HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
- && bop.getInput().get(1) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
- {
- Hop W = hi.getInput().get(0);
- Hop Y = bop.getInput().get(1).getInput().get(0);
- Hop tX = bop.getInput().get(1).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(tX) ) {
- tX = HopRewriteUtils.createTranspose(tX);
- }
- else
- tX = tX.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WSIGMOID, W, Y, tX, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");
- }
- }
- }
-
- //relink new hop into original position
- if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- hi = hnew;
- }
-
- return hi;
- }
-
- /**
- *
- * @param parent
- * @param hi
- * @param pos
- * @return
- * @throws HopsException
- */
- private Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos)
- throws HopsException
- {
- Hop hnew = null;
- boolean appliedPattern = false;
-
- //left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
- //note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
- if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()
- && (hi.getInput().get(0) instanceof BinaryOp
- && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
- || hi.getInput().get(1) instanceof BinaryOp
- && hi.getDim2() > 1 //not applied for vector-vector mult
- && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) )
- {
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
-
- //Pattern 1) t(U) %*% (W/(U%*%t(V)))
- //alternative pattern: t(U) %*% (W*(U%*%t(V)))
- if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
- && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
- && right.getInput().get(1) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- Hop W = right.getInput().get(0);
- Hop U = right.getInput().get(1).getInput().get(0);
- Hop V = right.getInput().get(1).getInput().get(1);
-
- if( HopRewriteUtils.isTransposeOfItself(left, U) )
- {
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = HopRewriteUtils.createTranspose(V);
- else
- V = V.getInput().get(0);
-
- boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WDIVMM, W, U, V, 1, mult, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- //add output transpose for efficient target indexing (redundant t() removed by other rewrites)
- hnew = HopRewriteUtils.createTranspose(hnew);
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
- }
- }
-
- //Pattern 2) (W/(U%*%t(V))) %*% V
- //alternative pattern: (W*(U%*%t(V))) %*% V
- if( !appliedPattern
- && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
- && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
- && left.getInput().get(1) instanceof AggBinaryOp
- && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- Hop W = left.getInput().get(0);
- Hop U = left.getInput().get(1).getInput().get(0);
- Hop V = left.getInput().get(1).getInput().get(1);
-
- if( HopRewriteUtils.isTransposeOfItself(right, V) )
- {
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = right;
- else
- V = V.getInput().get(0);
-
- boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WDIVMM, W, U, V, 2, mult, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
- }
- }
-
- //Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
- if( right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
- && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- Hop W = right.getInput().get(0);
- Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X = right.getInput().get(1).getInput().get(1);
-
- if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
- && HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
- {
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = HopRewriteUtils.createTranspose(V);
- else
- V = V.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WDIVMM, X, U, V, 1, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- //add output transpose for efficient target indexing (redundant t() removed by other rewrites)
- hnew = HopRewriteUtils.createTranspose(hnew);
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
- }
- }
-
- //Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
- if( !appliedPattern
- && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
- && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
- && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- Hop W = left.getInput().get(0);
- Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X = left.getInput().get(1).getInput().get(1);
-
- if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
- && HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint
- {
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = right;
- else
- V = V.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WDIVMM, X, U, V, 2, true, true);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
- }
- }
- }
-
- //Pattern 5) (W*(U%*%t(V)))
- if( !appliedPattern
- && hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
- && hi.getDim2() > 1 //not applied for vector-vector mult
- && hi.getInput().get(0).getDataType() == DataType.MATRIX
- && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
- && hi.getInput().get(1) instanceof AggBinaryOp
- && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
- && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
- {
- Hop W = hi.getInput().get(0);
- Hop U = hi.getInput().get(1).getInput().get(0);
- Hop V = hi.getInput().get(1).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = HopRewriteUtils.createTranspose(V);
- else
- V = V.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
- OpOp4.WDIVMM, W, U, V, 0, true, false);
- HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
-
- appliedPattern = true;
- LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
- }
-
- //relink new hop into original position
- if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- hi = hnew;
- }
-
- return hi;
- }
-
- /**
- *
- * @param parent
- * @param hi
- * @param pos
- * @return
- * @throws HopsException
- */
- private Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos)
- throws HopsException
- {
- Hop hnew = null;
-
- //Pattern 1) sum( X * log(U %*% t(V)))
- if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
- && ((AggUnaryOp)hi).getOp() == AggOp.SUM //pattern rooted by sum()
- && hi.getInput().get(0) instanceof BinaryOp //pattern subrooted by binary op
- && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
- {
- BinaryOp bop = (BinaryOp) hi.getInput().get(0);
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
-
- if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
- && HopRewriteUtils.isEqualSize(left, right) //prevent mb
- && right instanceof UnaryOp && ((UnaryOp)right).getOp()==OpOp1.LOG
- && right.getInput().get(0) instanceof AggBinaryOp //ba gurantees matrices
- && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
- {
- Hop X = left;
- Hop U = right.getInput().get(0).getInput().get(0);
- Hop V = right.getInput().get(0).getInput().get(1);
-
- if( !HopRewriteUtils.isTransposeOperation(V) )
- V = HopRewriteUtils.createTranspose(V);
- else
- V = V.getInput().get(0);
-
- hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V);
- HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock());
-
- LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")");
- }
- }
-
- //relink new hop into original position
- if( hnew != null ) {
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, hnew, pos);
- hi = hnew;
-=======
->>>>>>> 04aa86c Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict
- }
-
- return hi;
- }
-
/**
*
* @param parent