You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2015/11/23 04:53:00 UTC
[5/8] incubator-systemml git commit: Fix rewrite 'fuse sum_sq' (after
wsloss rewrite), for kmeans_predict
Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict
The new rewrite sum(X^2) -> sum_sq(X) was mistakenly applied before
sum((X-L%*%R)^2) -> wsloss(X,L,R), which consumed the pattern needed by
wsloss and hence led to a performance regression due to the shuffle
required for the join between X and L%*%R. This change moves all
quaternary rewrites into the category of 'dynamic' rewrites, since they
check various size parameters anyway. Because the quaternary rewrites are
now applied before 'fuse sum_sq', both rewrites trigger for their intended
patterns. On an 80GB kmeans use case, this change reduced the end-to-end
runtime from 590s to 114s.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/e52e0c0a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/e52e0c0a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/e52e0c0a
Branch: refs/heads/master
Commit: e52e0c0a6b5a39e375383b046e3ce0465ff8d662
Parents: cc4aae7
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Nov 21 00:51:27 2015 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Nov 22 19:38:23 2015 -0800
----------------------------------------------------------------------
.../RewriteAlgebraicSimplificationDynamic.java | 616 ++++++++++++++++++-
.../RewriteAlgebraicSimplificationStatic.java | 11 +-
2 files changed, 616 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e52e0c0a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 5c7a0fb..10dbedf 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -28,10 +28,12 @@ import com.ibm.bi.dml.hops.AggUnaryOp;
import com.ibm.bi.dml.hops.BinaryOp;
import com.ibm.bi.dml.hops.DataGenOp;
import com.ibm.bi.dml.hops.Hop;
+import com.ibm.bi.dml.hops.QuaternaryOp;
import com.ibm.bi.dml.hops.Hop.AggOp;
import com.ibm.bi.dml.hops.Hop.DataGenMethod;
import com.ibm.bi.dml.hops.Hop.Direction;
import com.ibm.bi.dml.hops.Hop.OpOp1;
+import com.ibm.bi.dml.hops.Hop.OpOp4;
import com.ibm.bi.dml.hops.Hop.ReOrgOp;
import com.ibm.bi.dml.hops.HopsException;
import com.ibm.bi.dml.hops.IndexingOp;
@@ -40,6 +42,7 @@ import com.ibm.bi.dml.hops.LiteralOp;
import com.ibm.bi.dml.hops.Hop.OpOp2;
import com.ibm.bi.dml.hops.ReorgOp;
import com.ibm.bi.dml.hops.UnaryOp;
+import com.ibm.bi.dml.lops.MapMultChain.ChainType;
import com.ibm.bi.dml.parser.DMLTranslator;
import com.ibm.bi.dml.parser.DataExpression;
import com.ibm.bi.dml.parser.Expression.DataType;
@@ -56,10 +59,8 @@ import com.ibm.bi.dml.parser.Expression.ValueType;
*/
public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
-
private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationDynamic.class.getName());
-
//valid aggregation operation types for rowOp to Op conversions (not all operations apply)
private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
@@ -70,6 +71,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
//valid unary operation types for empty (sparse-safe) operations (not all operations apply)
private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM};
+ //valid pseudo-sparse-safe binary operators for wdivmm
+ private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV};
@Override
@@ -159,10 +162,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector
hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector
+ hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true),
+ hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
+ hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
+ hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1
hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
- hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product
+ hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
@@ -1335,6 +1342,609 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+
+ /**
+ * Searches for weighted squared loss expressions and replaces them with a quaternary operator.
+ * Currently, this search includes the following three patterns:
+ * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
+ * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
+ * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
+ * Each pattern is also matched with the two operands of the minus swapped.
+ * NOTE: We include transpose into the pattern because during runtime we need to compute
+ * U %*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
+ * without additional memory requirements for internal transpose.
+ *
+ * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
+ * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
+ * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
+ * as we generalize the runtime or hop/lop compilation.
+ *
+ * @param parent parent hop; on a match, hi is replaced by the new hop at input position pos
+ * @param hi candidate root hop (all patterns are rooted by a full sum() aggregate)
+ * @param pos position of hi among the inputs of parent
+ * @return the new WSLOSS quaternary hop if a pattern matched, otherwise hi unchanged
+ * @throws HopsException propagated from the hop construction/relinking helpers
+ */
+ private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ //NOTE: there might be also a general simplification without custom operator
+ //via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
+ Hop hnew = null; //replacement hop; remains null if no pattern matches
+
+ if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && ((AggUnaryOp)hi).getOp() == AggOp.SUM //all patterns rooted by sum()
+ && hi.getInput().get(0) instanceof BinaryOp //all patterns subrooted by binary op
+ && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
+ {
+ BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+ boolean appliedPattern = false; //ensures at most one of the patterns below is applied
+
+ //Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
+ //alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
+ if( bop.getOp()==OpOp2.MULT && bop.getInput().get(1) instanceof BinaryOp
+ && bop.getInput().get(0).getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
+ && ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW
+ && bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
+ && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
+ {
+ Hop W = bop.getInput().get(0);
+ Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
+
+ if( tmp instanceof BinaryOp && ((BinaryOp)tmp).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
+ && tmp.getInput().get(0).getDataType() == DataType.MATRIX )
+ {
+ //a) sum (W * (X - U %*% t(V)) ^ 2)
+ int uvIndex = -1; //input position of U %*% t(V) within the minus; -1 if none
+ if( tmp.getInput().get(1) instanceof AggBinaryOp //ba guarantees matrices
+ && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ uvIndex = 1;
+ }
+ //b) sum (W * (U %*% t(V) - X) ^ 2)
+ else if(tmp.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
+ && HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ uvIndex = 0;
+ }
+
+ if( uvIndex >= 0 ) //rewrite match
+ {
+ Hop X = tmp.getInput().get((uvIndex==0)?1:0);
+ Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
+ Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(V) ) { //mm right input is t(V): strip t() if present, else add a transpose
+ V = HopRewriteUtils.createTranspose(V);
+ }
+ else{
+ V = V.getInput().get(0);
+ }
+
+ //handle special case of post_nz: if W is a nonzero indicator of X, use constant weight 1
+ if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
+ W = new LiteralOp(1);
+ }
+
+ //construct quaternary hop (flag true: post weighting)
+ hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
+ OpOp4.WSLOSS, X, U, V, W, true);
+ HopRewriteUtils.setOutputParametersForScalar(hnew);
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");
+ }
+ }
+ }
+
+ //Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
+ //alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
+ if( !appliedPattern
+ && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
+ && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+ && bop.getInput().get(0) instanceof BinaryOp
+ && bop.getInput().get(0).getDataType()==DataType.MATRIX
+ && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
+ && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+ {
+ Hop lleft = bop.getInput().get(0).getInput().get(0);
+ Hop lright = bop.getInput().get(0).getInput().get(1);
+
+ //a) sum ((X - W * (U %*% t(V))) ^ 2)
+ int wuvIndex = -1; //input position of W * (U %*% t(V)) within the minus; -1 if none
+ if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
+ wuvIndex = 1;
+ }
+ //b) sum ((W * (U %*% t(V)) - X) ^ 2)
+ else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
+ wuvIndex = 0;
+ }
+
+ if( wuvIndex >= 0 ) //rewrite match
+ {
+ Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
+ Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
+
+ if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
+ && tmp.getInput().get(0).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
+ && HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = tmp.getInput().get(0);
+ Hop U = tmp.getInput().get(1).getInput().get(0);
+ Hop V = tmp.getInput().get(1).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(V) ) {
+ V = HopRewriteUtils.createTranspose(V);
+ }
+ else {
+ V = V.getInput().get(0);
+ }
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
+ OpOp4.WSLOSS, X, U, V, W, false); //flag false: pre weighting
+ HopRewriteUtils.setOutputParametersForScalar(hnew);
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");
+ }
+ }
+ }
+
+ //Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
+ //alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
+ if( !appliedPattern
+ && bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
+ && HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+ && bop.getInput().get(0) instanceof BinaryOp
+ && bop.getInput().get(0).getDataType()==DataType.MATRIX
+ && ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
+ && HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
+ && bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+ {
+ Hop lleft = bop.getInput().get(0).getInput().get(0);
+ Hop lright = bop.getInput().get(0).getInput().get(1);
+
+ //a) sum ((X - (U %*% t(V))) ^ 2)
+ int uvIndex = -1; //input position of U %*% t(V) within the minus; -1 if none
+ if( lright instanceof AggBinaryOp //ba guarantees matrices
+ && HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ uvIndex = 1;
+ }
+ //b) sum (((U %*% t(V)) - X) ^ 2)
+ else if( lleft instanceof AggBinaryOp //ba guarantees matrices
+ && HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ uvIndex = 0;
+ }
+
+ if( uvIndex >= 0 ) //rewrite match
+ {
+ Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
+ Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
+ Hop W = new LiteralOp(1); //no weighting
+ Hop U = tmp.getInput().get(0);
+ Hop V = tmp.getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(V) ) {
+ V = HopRewriteUtils.createTranspose(V);
+ }
+ else {
+ V = V.getInput().get(0);
+ }
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE,
+ OpOp4.WSLOSS, X, U, V, W, false); //flag false with constant weight W=1
+ HopRewriteUtils.setOutputParametersForScalar(hnew);
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");
+ }
+ }
+ }
+
+ //relink new hop into original position
+ if( hnew != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, hnew, pos);
+ hi = hnew;
+ }
+
+ return hi;
+ }
+
+ /**
+ * Searches for W * [log](sigmoid([-](Y %*% t(X)))) expressions (patterns 1-4 below) and replaces them with a WSIGMOID quaternary operator.
+ * @param parent parent hop; on a match, hi is replaced by the new hop at input position pos
+ * @param hi candidate root hop (all patterns are rooted by a matrix W times a unary op)
+ * @param pos position of hi among the inputs of parent
+ * @return the new WSIGMOID quaternary hop if a pattern matched, otherwise hi unchanged
+ * @throws HopsException propagated from the hop construction/relinking helpers
+ */
+ private Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ Hop hnew = null;
+
+ if( hi instanceof BinaryOp //all patterns subrooted by W *
+ && ((BinaryOp) hi).getOp()==OpOp2.MULT
+ && hi.getDim2() > 1 //not applied for vector-vector mult
+ && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+ && hi.getInput().get(0).getDataType()==DataType.MATRIX
+ && hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
+ {
+ UnaryOp uop = (UnaryOp) hi.getInput().get(1);
+ boolean appliedPattern = false;
+
+ //Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
+ if( uop.getOp() == OpOp1.SIGMOID
+ && uop.getInput().get(0) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop Y = uop.getInput().get(0).getInput().get(0);
+ Hop tX = uop.getInput().get(0).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(tX) ) {
+ tX = HopRewriteUtils.createTranspose(tX);
+ }
+ else
+ tX = tX.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WSIGMOID, W, Y, tX, false, false); //flags (log=false, minus=false)
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");
+ }
+
+ //Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus; the unary minus is matched as 0-(Y%*%t(X)))
+ if( !appliedPattern
+ && uop.getOp() == OpOp1.SIGMOID
+ && uop.getInput().get(0) instanceof BinaryOp
+ && ((BinaryOp)uop.getInput().get(0)).getOp()==OpOp2.MINUS
+ && uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValueSafe(
+ (LiteralOp)uop.getInput().get(0).getInput().get(0))==0
+ && uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
+ Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(tX) ) {
+ tX = HopRewriteUtils.createTranspose(tX);
+ }
+ else
+ tX = tX.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WSIGMOID, W, Y, tX, false, true); //flags (log=false, minus=true)
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");
+ }
+
+ //Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
+ if( !appliedPattern
+ && uop.getOp() == OpOp1.LOG
+ && uop.getInput().get(0) instanceof UnaryOp
+ && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
+ && uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
+ Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(tX) ) {
+ tX = HopRewriteUtils.createTranspose(tX);
+ }
+ else
+ tX = tX.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WSIGMOID, W, Y, tX, true, false); //flags (log=true, minus=false)
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");
+ }
+
+ //Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus; the unary minus is matched as 0-(Y%*%t(X)))
+ if( !appliedPattern
+ && uop.getOp() == OpOp1.LOG
+ && uop.getInput().get(0) instanceof UnaryOp
+ && ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID
+ && uop.getInput().get(0).getInput().get(0) instanceof BinaryOp )
+ {
+ BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
+
+ if( bop.getOp() == OpOp2.MINUS
+ && bop.getInput().get(0) instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
+ && bop.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop Y = bop.getInput().get(1).getInput().get(0);
+ Hop tX = bop.getInput().get(1).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(tX) ) {
+ tX = HopRewriteUtils.createTranspose(tX);
+ }
+ else
+ tX = tX.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WSIGMOID, W, Y, tX, true, true); //flags (log=true, minus=true)
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");
+ }
+ }
+ }
+
+ //relink new hop into original position
+ if( hnew != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, hnew, pos);
+ hi = hnew;
+ }
+
+ return hi;
+ }
+
+ /**
+ * Searches for weighted div/mult expressions over U%*%t(V), optionally inside a matrix multiply (patterns 1-5 below), and replaces them with a WDIVMM quaternary operator.
+ * @param parent parent hop; on a match, hi is replaced by the new hop at input position pos
+ * @param hi candidate root hop (patterns 1-4 rooted by a matrix multiply, pattern 5 by a binary mult)
+ * @param pos position of hi among the inputs of parent
+ * @return the new WDIVMM hop (wrapped in a transpose for patterns 1 and 3) if a pattern matched, otherwise hi unchanged
+ * @throws HopsException propagated from the hop construction/relinking helpers
+ */
+ private Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ Hop hnew = null;
+ boolean appliedPattern = false;
+
+ //left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
+ //note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops)
+ if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()
+ && (hi.getInput().get(0) instanceof BinaryOp
+ && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
+ || hi.getInput().get(1) instanceof BinaryOp
+ && hi.getDim2() > 1 //not applied for vector-vector mult; NOTE(review): by &&/|| precedence this guards only the right-input branch -- confirm intended
+ && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) )
+ {
+ Hop left = hi.getInput().get(0);
+ Hop right = hi.getInput().get(1);
+
+ //Pattern 1) t(U) %*% (W/(U%*%t(V)))
+ //alternative pattern: t(U) %*% (W*(U%*%t(V)))
+ if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
+ && HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
+ && right.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = right.getInput().get(0);
+ Hop U = right.getInput().get(1).getInput().get(0);
+ Hop V = right.getInput().get(1).getInput().get(1);
+
+ if( HopRewriteUtils.isTransposeOfItself(left, U) ) //left input must be t(U)
+ {
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WDIVMM, W, U, V, 1, mult, false);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ //add output transpose for efficient target indexing (redundant t() removed by other rewrites)
+ hnew = HopRewriteUtils.createTranspose(hnew);
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");
+ }
+ }
+
+ //Pattern 2) (W/(U%*%t(V))) %*% V
+ //alternative pattern: (W*(U%*%t(V))) %*% V
+ if( !appliedPattern
+ && left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
+ && HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
+ && left.getInput().get(1) instanceof AggBinaryOp
+ && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = left.getInput().get(0);
+ Hop U = left.getInput().get(1).getInput().get(0);
+ Hop V = left.getInput().get(1).getInput().get(1);
+
+ if( HopRewriteUtils.isTransposeOfItself(right, V) ) //right input must be the transpose of the mm's t(V) input
+ {
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = right;
+ else
+ V = V.getInput().get(0);
+
+ boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WDIVMM, W, U, V, 2, mult, false);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");
+ }
+ }
+
+ //Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X)); NOTE(review): unlike patterns 2/4/5 there is no !appliedPattern guard -- confirm an earlier match cannot be overwritten
+ if( right instanceof BinaryOp && ((BinaryOp)right).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
+ && right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS
+ && right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
+ && right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = right.getInput().get(0);
+ Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
+ Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
+ Hop X = right.getInput().get(1).getInput().get(1);
+
+ if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
+ && HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
+ {
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WDIVMM, X, U, V, 1, true, true);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ //add output transpose for efficient target indexing (redundant t() removed by other rewrites)
+ hnew = HopRewriteUtils.createTranspose(hnew);
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");
+ }
+ }
+
+ //Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
+ if( !appliedPattern
+ && left instanceof BinaryOp && ((BinaryOp)left).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
+ && left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS
+ && left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
+ && left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+ && HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = left.getInput().get(0);
+ Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
+ Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
+ Hop X = left.getInput().get(1).getInput().get(1);
+
+ if( HopRewriteUtils.isNonZeroIndicator(W, X) //W-X constraint
+ && HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint
+ {
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = right;
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WDIVMM, X, U, V, 2, true, true);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");
+ }
+ }
+ }
+
+ //Pattern 5) (W*(U%*%t(V))) (rooted by the binary mult itself, not by a matrix multiply)
+ if( !appliedPattern
+ && hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==LOOKUP_VALID_WDIVMM_BINARY[0] //MULT
+ && HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+ && hi.getDim2() > 1 //not applied for vector-vector mult
+ && hi.getInput().get(0).getDataType() == DataType.MATRIX
+ && hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock() //W spans more than one column block
+ && hi.getInput().get(1) instanceof AggBinaryOp
+ && (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+ && HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+ {
+ Hop W = hi.getInput().get(0);
+ Hop U = hi.getInput().get(1).getInput().get(0);
+ Hop V = hi.getInput().get(1).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WDIVMM, W, U, V, 0, true, false);
+ HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");
+ }
+
+ //relink new hop into original position
+ if( hnew != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, hnew, pos);
+ hi = hnew;
+ }
+
+ return hi;
+ }
+
+ /**
+ * Searches for the weighted cross entropy expression sum(X * log(U %*% t(V))) and replaces it with a WCEMM quaternary operator.
+ * @param parent parent hop; on a match, hi is replaced by the new hop at input position pos
+ * @param hi candidate root hop (the pattern is rooted by a full sum() aggregate)
+ * @param pos position of hi among the inputs of parent
+ * @return the new WCEMM quaternary hop if the pattern matched, otherwise hi unchanged
+ * @throws HopsException propagated from the hop construction/relinking helpers
+ */
+ private Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos)
+ throws HopsException
+ {
+ Hop hnew = null;
+
+ //Pattern 1) sum( X * log(U %*% t(V)))
+ if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && ((AggUnaryOp)hi).getOp() == AggOp.SUM //pattern rooted by sum()
+ && hi.getInput().get(0) instanceof BinaryOp //pattern subrooted by binary op
+ && hi.getInput().get(0).getDim2() > 1 ) //not applied for vector-vector mult
+ {
+ BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+ Hop left = bop.getInput().get(0);
+ Hop right = bop.getInput().get(1);
+
+ if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isEqualSize(left, right) //prevent mv
+ && right instanceof UnaryOp && ((UnaryOp)right).getOp()==OpOp1.LOG
+ && right.getInput().get(0) instanceof AggBinaryOp //ba guarantees matrices
+ && HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+ {
+ Hop X = left;
+ Hop U = right.getInput().get(0).getInput().get(0);
+ Hop V = right.getInput().get(0).getInput().get(1);
+
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V);
+ HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock()); //NOTE(review): scalar output, yet WSLOSS uses setOutputParametersForScalar -- confirm
+
+ LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")");
+ }
+ }
+
+ //relink new hop into original position
+ if( hnew != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, hnew, pos);
+ hi = hnew;
+ }
+
+ return hi;
+ }
+
/**
* NOTE: dot-product-sum could be also applied to sum(a*b). However, we
* restrict ourselfs to sum(a^2) and transitively sum(a*a) since a general mm
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e52e0c0a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4fc23eb..d2887a2 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -29,9 +29,7 @@ import com.ibm.bi.dml.hops.BinaryOp;
import com.ibm.bi.dml.hops.DataGenOp;
import com.ibm.bi.dml.hops.Hop;
import com.ibm.bi.dml.hops.Hop.OpOp1;
-import com.ibm.bi.dml.hops.Hop.OpOp4;
import com.ibm.bi.dml.hops.IndexingOp;
-import com.ibm.bi.dml.hops.QuaternaryOp;
import com.ibm.bi.dml.hops.TernaryOp;
import com.ibm.bi.dml.hops.UnaryOp;
import com.ibm.bi.dml.hops.Hop.AggOp;
@@ -44,7 +42,6 @@ import com.ibm.bi.dml.hops.LiteralOp;
import com.ibm.bi.dml.hops.Hop.OpOp2;
import com.ibm.bi.dml.hops.ParameterizedBuiltinOp;
import com.ibm.bi.dml.hops.ReorgOp;
-import com.ibm.bi.dml.lops.MapMultChain.ChainType;
import com.ibm.bi.dml.parser.DataExpression;
import com.ibm.bi.dml.parser.Statement;
import com.ibm.bi.dml.parser.Expression.DataType;
@@ -66,7 +63,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS};
private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT};
- private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV};
@Override
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state)
@@ -149,10 +145,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = removeUnnecessaryTranspose(hop, hi, i); //e.g., t(t(X))->X; potentially introduced by diag/trace_MM
hi = removeUnnecessaryMinus(hop, hi, i); //e.g., -(-X)->X; potentially introduced by simplfiy binary or dyn rewrites
hi = simplifyGroupedAggregate(hi); //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
- hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true)
- hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
- hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
- hi = simplifyWeightedCrossEntropy(hop, hi, i); //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
hi = fuseMinusNzBinaryOperation(hop, hi, i); //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
@@ -1298,6 +1290,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
LOG.debug("Applied simplifyGroupedAggregateCount");
}
}
+<<<<<<< Upstream, based on branch 'master' of https://git-wip-us.apache.org/repos/asf/incubator-systemml.git
}
return hi;
@@ -1900,6 +1893,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
HopRewriteUtils.addChildReference(parent, hnew, pos);
hi = hnew;
+=======
+>>>>>>> 04aa86c Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict
}
return hi;