You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2015/11/23 04:53:00 UTC

[5/8] incubator-systemml git commit: Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict

Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict 

The new rewrite sum(X^2) -> sum_sq(X) was mistakenly applied before
sum((X-L%*%R)^2) -> wsloss(X,L,R), consuming the sum(...^2) pattern so that
wsloss could no longer match; this led to a performance regression due to
a shuffle for the join between X and L%*%R. This change moves all
quaternary rewrites to the category of 'dynamic' rewrites because they
check various size parameters anyway. By applying the quaternary rewrites
before 'fuse sum_sq', both rewrites trigger for their intended patterns.
On an 80GB kmeans use case, this change reduced the end-to-end runtime from 590s to 114s.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/e52e0c0a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/e52e0c0a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/e52e0c0a

Branch: refs/heads/master
Commit: e52e0c0a6b5a39e375383b046e3ce0465ff8d662
Parents: cc4aae7
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Nov 21 00:51:27 2015 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sun Nov 22 19:38:23 2015 -0800

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationDynamic.java  | 616 ++++++++++++++++++-
 .../RewriteAlgebraicSimplificationStatic.java   |  11 +-
 2 files changed, 616 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e52e0c0a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 5c7a0fb..10dbedf 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -28,10 +28,12 @@ import com.ibm.bi.dml.hops.AggUnaryOp;
 import com.ibm.bi.dml.hops.BinaryOp;
 import com.ibm.bi.dml.hops.DataGenOp;
 import com.ibm.bi.dml.hops.Hop;
+import com.ibm.bi.dml.hops.QuaternaryOp;
 import com.ibm.bi.dml.hops.Hop.AggOp;
 import com.ibm.bi.dml.hops.Hop.DataGenMethod;
 import com.ibm.bi.dml.hops.Hop.Direction;
 import com.ibm.bi.dml.hops.Hop.OpOp1;
+import com.ibm.bi.dml.hops.Hop.OpOp4;
 import com.ibm.bi.dml.hops.Hop.ReOrgOp;
 import com.ibm.bi.dml.hops.HopsException;
 import com.ibm.bi.dml.hops.IndexingOp;
@@ -40,6 +42,7 @@ import com.ibm.bi.dml.hops.LiteralOp;
 import com.ibm.bi.dml.hops.Hop.OpOp2;
 import com.ibm.bi.dml.hops.ReorgOp;
 import com.ibm.bi.dml.hops.UnaryOp;
+import com.ibm.bi.dml.lops.MapMultChain.ChainType;
 import com.ibm.bi.dml.parser.DMLTranslator;
 import com.ibm.bi.dml.parser.DataExpression;
 import com.ibm.bi.dml.parser.Expression.DataType;
@@ -56,10 +59,8 @@ import com.ibm.bi.dml.parser.Expression.ValueType;
  */
 public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 {
-	
 	private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationDynamic.class.getName());
 	
-	
 	//valid aggregation operation types for rowOp to Op conversions (not all operations apply)
 	private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
 	
@@ -70,6 +71,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 	//valid unary operation types for empty (sparse-safe) operations (not all operations apply)
 	private static OpOp1[] LOOKUP_VALID_EMPTY_UNARY = new OpOp1[]{OpOp1.ABS, OpOp1.SIN, OpOp1.TAN, OpOp1.SQRT, OpOp1.ROUND, OpOp1.CUMSUM}; 
 	
+	//valid pseudo-sparse-safe binary operators for wdivmm 
+	private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV}; 
 	
 	
 	@Override
@@ -159,10 +162,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = simplifyDiagMatrixMult(hop, hi, i);          //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
 			hi = simplifySumDiagToTrace(hi);                  //e.g., sum(diag(X)) -> trace(X); if col vector
 			hi = pushdownBinaryOperationOnDiag(hop, hi, i);   //e.g., diag(X)*7 -> diag(X*7); if col vector
+			hi = simplifyWeightedSquaredLoss(hop, hi, i);     //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true), 
+			hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
+			hi = simplifyWeightedDivMM(hop, hi, i);           //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
+			hi = simplifyWeightedCrossEntropy(hop, hi, i);    //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
 			hi = simplifyDotProductSum(hop, hi, i);           //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 
 			hi = fuseSumSquared(hop, hi, i);                  //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
 			hi = reorderMinusMatrixMult(hop, hi, i);          //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
-			hi = simplifySumMatrixMult(hop, hi, i);           //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product
+			hi = simplifySumMatrixMult(hop, hi, i);           //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
 			hi = simplifyEmptyBinaryOperation(hop, hi, i);    //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
 			hi = simplifyScalarMVBinaryOperation(hi); 		  //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
 			hi = simplifyNnzComputation(hop, hi, i);          //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
@@ -1335,6 +1342,609 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		return hi;
 	}
 	
+
+	/**
+	 * Searches for weighted squared loss expressions and replaces them with a WSLOSS quaternary operator. 
+	 * Currently, this search includes the following three patterns (each also in its mirrored form U %*% t(V) - X):
+	 * 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
+	 * 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
+	 * 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
+	 * 
+	 * NOTE: We include transpose into the pattern because during runtime we need to compute
+	 * U%*% t(V) pointwise; having V and not t(V) at hand allows for a cache-friendly implementation
+	 * without additional memory requirements for internal transpose.
+	 * 
+	 * This rewrite is conceptually a static rewrite; however, the current MR runtime only supports
+	 * U/V factors of rank up to the blocksize (1000). We enforce this constraint here during the general
+	 * rewrite because this is an uncommon case. Also, the intention is to remove this constraint as soon
+	 * as we have generalized the runtime or hop/lop compilation. 
+	 * 
+	 * @param parent parent hop that references {@code hi} as its input at position {@code pos}
+	 * @param hi candidate root hop; only full sum(...) aggregates (Direction.RowCol, AggOp.SUM) are matched
+	 * @param pos input position of {@code hi} within {@code parent}; used to relink the new hop
+	 * @return the new scalar WSLOSS quaternary hop if a pattern matched, otherwise {@code hi} unchanged
+	 * @throws HopsException propagated from the hop construction and rewrite utilities
+	 */
+	private Hop simplifyWeightedSquaredLoss(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		//NOTE: there might be also a general simplification without custom operator
+		//via (X-UVt)^2 -> X^2 - 2X*UVt + UVt^2
+		Hop hnew = null;
+		
+		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM     //all patterns rooted by sum()
+			&& hi.getInput().get(0) instanceof BinaryOp  //all patterns subrooted by binary op
+			&& hi.getInput().get(0).getDim2() > 1  )     //not applied for vector-vector mult
+		{
+			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+			boolean appliedPattern = false;
+			
+			//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post weighting)
+			//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
+			if( bop.getOp()==OpOp2.MULT && bop.getInput().get(1) instanceof BinaryOp	
+				&& bop.getInput().get(0).getDataType()==DataType.MATRIX	
+				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1)) //prevent mv
+				&& ((BinaryOp)bop.getInput().get(1)).getOp()==OpOp2.POW 
+				&& bop.getInput().get(1).getInput().get(1) instanceof LiteralOp
+				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1).getInput().get(1))==2)
+			{
+				Hop W = bop.getInput().get(0);
+				Hop tmp = bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
+				
+				if( tmp instanceof BinaryOp && ((BinaryOp)tmp).getOp()==OpOp2.MINUS
+					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv	
+					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX )
+				{
+					//a) sum (W * (X - U %*% t(V)) ^ 2)
+					int uvIndex = -1;
+					if( tmp.getInput().get(1) instanceof AggBinaryOp  //ba guarantees matrices
+							&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+					{
+						uvIndex = 1;   
+					}
+					//b) sum (W * (U %*% t(V) - X) ^ 2)
+					else if(tmp.getInput().get(0) instanceof AggBinaryOp  //ba guarantees matrices
+						&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+					{
+						uvIndex = 0;
+					}   
+				 
+					if( uvIndex >= 0 ) //rewrite match
+					{
+						Hop X = tmp.getInput().get((uvIndex==0)?1:0); 
+						Hop U = tmp.getInput().get(uvIndex).getInput().get(0);
+						Hop V = tmp.getInput().get(uvIndex).getInput().get(1);
+	                    
+						if( !HopRewriteUtils.isTransposeOperation(V) ) {
+							V = HopRewriteUtils.createTranspose(V);
+						}
+						else{
+							V = V.getInput().get(0);
+						}
+	                    
+						//handle special case of post_nz (weights replaced by literal 1 if W is a nonzero indicator of X)
+						if( HopRewriteUtils.isNonZeroIndicator(W, X) ){
+							W = new LiteralOp(1);
+						}
+						
+						//construct quaternary hop
+						hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
+								OpOp4.WSLOSS, X, U, V, W, true);
+						HopRewriteUtils.setOutputParametersForScalar(hnew);
+	
+						appliedPattern = true;
+						LOG.debug("Applied simplifyWeightedSquaredLoss1"+uvIndex+" (line "+hi.getBeginLine()+")");  
+					}
+				}
+			}
+			
+			//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre weighting)
+			//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
+			if( !appliedPattern
+				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
+				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+				&& bop.getInput().get(0) instanceof BinaryOp	
+				&& bop.getInput().get(0).getDataType()==DataType.MATRIX	
+				&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
+				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
+				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+			{
+			    Hop lleft = bop.getInput().get(0).getInput().get(0); 
+			    Hop lright = bop.getInput().get(0).getInput().get(1); 
+                
+			    //a) sum ((X - W * (U %*% t(V))) ^ 2)
+			    int wuvIndex = -1; //NOTE: only the first structural match is used; failing the MULT/blocksize checks below does not fall back to b)
+			    if( lright instanceof BinaryOp && lright.getInput().get(1) instanceof AggBinaryOp ){
+			    	wuvIndex = 1;
+			    }
+			    //b) sum ((W * (U %*% t(V)) - X) ^ 2)
+			    else if( lleft instanceof BinaryOp && lleft.getInput().get(1) instanceof AggBinaryOp ){
+			    	wuvIndex = 0;
+			    }
+			    
+			    if( wuvIndex >= 0 ) //structural match; MULT and BLOCKSIZE CONSTRAINT are checked below
+			    {
+			    	Hop X = bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
+			    	Hop tmp = bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
+    				
+    				if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
+    					&& tmp.getInput().get(0).getDataType() == DataType.MATRIX	
+    					&& HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1)) //prevent mv
+    					&& HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+    				{
+    					Hop W = tmp.getInput().get(0); 
+    					Hop U = tmp.getInput().get(1).getInput().get(0);
+    					Hop V = tmp.getInput().get(1).getInput().get(1);
+    					
+    					if( !HopRewriteUtils.isTransposeOperation(V) ) { 
+    						V = HopRewriteUtils.createTranspose(V);
+    					}
+    					else {
+    						V = V.getInput().get(0);
+    					}
+    					
+    					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
+    							  OpOp4.WSLOSS, X, U, V, W, false);
+    					HopRewriteUtils.setOutputParametersForScalar(hnew);
+    
+    					appliedPattern = true;
+    					LOG.debug("Applied simplifyWeightedSquaredLoss2"+wuvIndex+" (line "+hi.getBeginLine()+")");	
+    				}
+			    }
+			}
+			
+			//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
+			//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
+			if( !appliedPattern
+				&& bop.getOp()==OpOp2.POW && bop.getInput().get(1) instanceof LiteralOp
+				&& HopRewriteUtils.getIntValue((LiteralOp)bop.getInput().get(1))==2
+				&& bop.getInput().get(0) instanceof BinaryOp	
+				&& bop.getInput().get(0).getDataType()==DataType.MATRIX	
+				&& ((BinaryOp)bop.getInput().get(0)).getOp()==OpOp2.MINUS
+				&& HopRewriteUtils.isEqualSize(bop.getInput().get(0).getInput().get(0), bop.getInput().get(0).getInput().get(1)) //prevent mv
+				&& bop.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX)
+			{
+				Hop lleft = bop.getInput().get(0).getInput().get(0);
+				Hop lright = bop.getInput().get(0).getInput().get(1);
+                
+				//a) sum ((X - (U %*% t(V))) ^ 2)
+				int uvIndex = -1;
+				if( lright instanceof AggBinaryOp //ba guarantees matrices
+					&& HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) )  //BLOCKSIZE CONSTRAINT
+				{
+					uvIndex = 1;
+				}
+				//b) sum (((U %*% t(V)) - X) ^ 2)
+				else if( lleft instanceof AggBinaryOp //ba guarantees matrices
+						&& HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) )  //BLOCKSIZE CONSTRAINT
+				{
+					uvIndex = 0;
+				}
+			    
+				if( uvIndex >= 0 ) //rewrite match
+				{
+					Hop X = bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
+					Hop tmp = bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
+					Hop W = new LiteralOp(1); //no weighting 
+					Hop U = tmp.getInput().get(0);
+					Hop V = tmp.getInput().get(1);
+	
+					if( !HopRewriteUtils.isTransposeOperation(V) ) { 
+						V = HopRewriteUtils.createTranspose(V);
+					}
+					else {
+						V = V.getInput().get(0);
+					}
+					
+					hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, 
+							  OpOp4.WSLOSS, X, U, V, W, false);
+					HopRewriteUtils.setOutputParametersForScalar(hnew);
+
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedSquaredLoss3"+uvIndex+" (line "+hi.getBeginLine()+")");	
+				}
+			}			
+		}
+		
+		//relink new hop into original position
+		if( hnew != null ) {
+			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+			HopRewriteUtils.addChildReference(parent, hnew, pos);
+			hi = hnew;
+		}
+		
+		return hi;
+	}
+	
+	/**
+	 * Searches for W * sigmoid(Y%*%t(X)) chains (incl. minus and log variants, see Pattern 1-4) and replaces them with a WSIGMOID quaternary operator.
+	 * @param parent parent hop that references {@code hi} as its input at position {@code pos}
+	 * @param hi candidate root hop; only elementwise W * (...) multiplications of equal-size matrices are matched
+	 * @param pos input position of {@code hi} within {@code parent}; used to relink the new hop
+	 * @return the new matrix WSIGMOID quaternary hop if a pattern matched, otherwise {@code hi} unchanged
+	 * @throws HopsException propagated from the hop construction and rewrite utilities
+	 */
+	private Hop simplifyWeightedSigmoidMMChains(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		Hop hnew = null;
+		
+		if(    hi instanceof BinaryOp //all patterns subrooted by W *
+			&& ((BinaryOp) hi).getOp()==OpOp2.MULT
+			&& hi.getDim2() > 1       //not applied for vector-vector mult
+			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+			&& hi.getInput().get(0).getDataType()==DataType.MATRIX 
+			&& hi.getInput().get(1) instanceof UnaryOp ) //sigmoid/log
+		{
+			UnaryOp uop = (UnaryOp) hi.getInput().get(1);
+			boolean appliedPattern = false;
+			
+			//Pattern 1) W * sigmoid(Y%*%t(X)) (basic; trailing WSIGMOID flags false,false - presumably (log, minus))
+			if(    uop.getOp() == OpOp1.SIGMOID 
+				&& uop.getInput().get(0) instanceof AggBinaryOp
+				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = hi.getInput().get(0); 
+				Hop Y = uop.getInput().get(0).getInput().get(0);
+				Hop tX = uop.getInput().get(0).getInput().get(1);
+				
+				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
+					tX = HopRewriteUtils.createTranspose(tX);
+				}
+				else 
+					tX = tX.getInput().get(0);
+				
+				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+						  OpOp4.WSIGMOID, W, Y, tX, false, false);
+				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+				appliedPattern = true;
+				LOG.debug("Applied simplifyWeightedSigmoid1 (line "+hi.getBeginLine()+")");	
+			}
+			
+			//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus; flags false,true), where unary minus appears as 0 - (...)
+			if(    !appliedPattern 
+				&& uop.getOp() == OpOp1.SIGMOID 
+				&& uop.getInput().get(0) instanceof BinaryOp
+				&& ((BinaryOp)uop.getInput().get(0)).getOp()==OpOp2.MINUS
+				&& uop.getInput().get(0).getInput().get(0) instanceof LiteralOp
+				&& HopRewriteUtils.getDoubleValueSafe(
+				   (LiteralOp)uop.getInput().get(0).getInput().get(0))==0
+				&& uop.getInput().get(0).getInput().get(1) instanceof AggBinaryOp
+				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = hi.getInput().get(0); 
+				Hop Y = uop.getInput().get(0).getInput().get(1).getInput().get(0);
+				Hop tX = uop.getInput().get(0).getInput().get(1).getInput().get(1);
+				
+				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
+					tX = HopRewriteUtils.createTranspose(tX);
+				}
+				else 
+					tX = tX.getInput().get(0);
+				
+				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+						  OpOp4.WSIGMOID, W, Y, tX, false, true);
+				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+				appliedPattern = true;
+				LOG.debug("Applied simplifyWeightedSigmoid2 (line "+hi.getBeginLine()+")");	
+			}
+			
+			//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log; flags true,false)
+			if(    !appliedPattern 
+				&& uop.getOp() == OpOp1.LOG
+				&& uop.getInput().get(0) instanceof UnaryOp
+				&& ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID 
+				&& uop.getInput().get(0).getInput().get(0) instanceof AggBinaryOp
+				&& HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = hi.getInput().get(0); 
+				Hop Y = uop.getInput().get(0).getInput().get(0).getInput().get(0);
+				Hop tX = uop.getInput().get(0).getInput().get(0).getInput().get(1);
+				
+				if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
+					tX = HopRewriteUtils.createTranspose(tX);
+				}
+				else 
+					tX = tX.getInput().get(0);
+				
+				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+						  OpOp4.WSIGMOID, W, Y, tX, true, false);
+				HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+				appliedPattern = true;
+				LOG.debug("Applied simplifyWeightedSigmoid3 (line "+hi.getBeginLine()+")");	
+			}			
+			
+			//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus; flags true,true)
+			if(    !appliedPattern 
+				&& uop.getOp() == OpOp1.LOG
+				&& uop.getInput().get(0) instanceof UnaryOp
+				&& ((UnaryOp)uop.getInput().get(0)).getOp() == OpOp1.SIGMOID 
+				&& uop.getInput().get(0).getInput().get(0) instanceof BinaryOp )
+			{
+				BinaryOp bop = (BinaryOp) uop.getInput().get(0).getInput().get(0);
+				
+				if(    bop.getOp() == OpOp2.MINUS 
+					&& bop.getInput().get(0) instanceof LiteralOp
+					&& HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
+					&& bop.getInput().get(1) instanceof AggBinaryOp
+					&& HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+				{
+					Hop W = hi.getInput().get(0); 
+					Hop Y = bop.getInput().get(1).getInput().get(0);
+					Hop tX = bop.getInput().get(1).getInput().get(1);
+					
+					if( !HopRewriteUtils.isTransposeOperation(tX) ) { 
+						tX = HopRewriteUtils.createTranspose(tX);
+					}
+					else 
+						tX = tX.getInput().get(0);
+					
+					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+							  OpOp4.WSIGMOID, W, Y, tX, true, true);
+					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+	
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedSigmoid4 (line "+hi.getBeginLine()+")");	
+				}
+			}
+		}
+		
+		//relink new hop into original position
+		if( hnew != null ) {
+			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+			HopRewriteUtils.addChildReference(parent, hnew, pos);
+			hi = hnew;
+		}
+		
+		return hi;
+	}
+
+	/**
+	 * Searches for weighted divide/multiply matrix-multiply chains (see Pattern 1-5 below) and replaces them with a WDIVMM quaternary operator.
+	 * @param parent parent hop that references {@code hi} as its input at position {@code pos}
+	 * @param hi candidate root hop (matrix multiply for patterns 1-4, elementwise W*(U%*%t(V)) for pattern 5)
+	 * @param pos input position of {@code hi} within {@code parent}; used to relink the new hop
+	 * @return the new WDIVMM hop (wrapped in a transpose for patterns 1 and 3) if a pattern matched, otherwise {@code hi}
+	 * @throws HopsException propagated from the hop construction and rewrite utilities
+	 */
+	private Hop simplifyWeightedDivMM(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		Hop hnew = null;
+		boolean appliedPattern = false; //guards every pattern so that at most one rewrite builds hnew
+		
+		//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
+		//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are vectors (see mmchain ops) 
+		if( hi instanceof AggBinaryOp && ((AggBinaryOp)hi).isMatrixMultiply()  
+			&& ( (hi.getInput().get(0) instanceof BinaryOp //left candidate (patterns 2/4)
+			      && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(), LOOKUP_VALID_WDIVMM_BINARY))
+			  || (hi.getInput().get(1) instanceof BinaryOp //right candidate (patterns 1/3)
+			      && hi.getDim2() > 1 //not applied for vector-vector mult
+			      && HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(), LOOKUP_VALID_WDIVMM_BINARY)) ) ) 
+		{
+			Hop left = hi.getInput().get(0);
+			Hop right = hi.getInput().get(1);
+			
+			//Pattern 1) t(U) %*% (W/(U%*%t(V)))
+			//alternative pattern: t(U) %*% (W*(U%*%t(V)))
+			if( right instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)	
+				&& HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1)) //prevent mv
+				&& right.getInput().get(1) instanceof AggBinaryOp
+				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = right.getInput().get(0); 
+				Hop U = right.getInput().get(1).getInput().get(0);
+				Hop V = right.getInput().get(1).getInput().get(1);
+				
+				if( HopRewriteUtils.isTransposeOfItself(left, U) ) 
+				{
+					if( !HopRewriteUtils.isTransposeOperation(V) )
+						V = HopRewriteUtils.createTranspose(V);
+					else 
+						V = V.getInput().get(0);
+					
+					boolean mult = ((BinaryOp)right).getOp() == OpOp2.MULT;
+					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+							  OpOp4.WDIVMM, W, U, V, 1, mult, false);
+					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+					
+					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
+					hnew = HopRewriteUtils.createTranspose(hnew);
+					
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedDivMM1 (line "+hi.getBeginLine()+")");					
+				}
+			}	
+			
+			//Pattern 2) (W/(U%*%t(V))) %*% V
+			//alternative pattern: (W*(U%*%t(V))) %*% V
+			if( !appliedPattern
+				&& left instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)	
+				&& HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1)) //prevent mv
+				&& left.getInput().get(1) instanceof AggBinaryOp
+				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = left.getInput().get(0); 
+				Hop U = left.getInput().get(1).getInput().get(0);
+				Hop V = left.getInput().get(1).getInput().get(1);
+				
+				if( HopRewriteUtils.isTransposeOfItself(right, V) ) 
+				{
+					if( !HopRewriteUtils.isTransposeOperation(V) )
+						V = right; //right already is t(V); reuse it instead of creating a new transpose
+					else 
+						V = V.getInput().get(0);
+					
+					boolean mult = ((BinaryOp)left).getOp() == OpOp2.MULT;
+					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+							  OpOp4.WDIVMM, W, U, V, 2, mult, false);
+					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedDivMM2 (line "+hi.getBeginLine()+")");	
+				}
+			}
+			
+			//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
+			if( !appliedPattern && right instanceof BinaryOp && ((BinaryOp)right).getOp()==OpOp2.MULT
+				&& right.getInput().get(1) instanceof BinaryOp && ((BinaryOp)right.getInput().get(1)).getOp()==OpOp2.MINUS	
+				&& right.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
+				&& right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = right.getInput().get(0); 
+				Hop U = right.getInput().get(1).getInput().get(0).getInput().get(0);
+				Hop V = right.getInput().get(1).getInput().get(0).getInput().get(1);
+				Hop X = right.getInput().get(1).getInput().get(1);
+				
+				if(    HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
+				    && HopRewriteUtils.isTransposeOfItself(left, U) )  //t(U)-U constraint
+				{
+					if( !HopRewriteUtils.isTransposeOperation(V) )
+						V = HopRewriteUtils.createTranspose(V);
+					else 
+						V = V.getInput().get(0);
+					
+					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+							  OpOp4.WDIVMM, X, U, V, 1, true, true);
+					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+					
+					//add output transpose for efficient target indexing (redundant t() removed by other rewrites)
+					hnew = HopRewriteUtils.createTranspose(hnew);
+					
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedDivMM3 (line "+hi.getBeginLine()+")");					
+				}
+			}	
+			
+			//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
+			if( !appliedPattern
+				&& left instanceof BinaryOp && ((BinaryOp)left).getOp()==OpOp2.MULT
+				&& left.getInput().get(1) instanceof BinaryOp && ((BinaryOp)left.getInput().get(1)).getOp()==OpOp2.MINUS	
+				&& left.getInput().get(1).getInput().get(0) instanceof AggBinaryOp
+				&& left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
+				&& HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+			{
+				Hop W = left.getInput().get(0); 
+				Hop U = left.getInput().get(1).getInput().get(0).getInput().get(0);
+				Hop V = left.getInput().get(1).getInput().get(0).getInput().get(1);
+				Hop X = left.getInput().get(1).getInput().get(1);
+				
+				if(    HopRewriteUtils.isNonZeroIndicator(W, X)        //W-X constraint
+					&& HopRewriteUtils.isTransposeOfItself(right, V) )  //V-t(V) constraint
+				{
+					if( !HopRewriteUtils.isTransposeOperation(V) )
+						V = right; //right already is t(V); reuse it instead of creating a new transpose
+					else 
+						V = V.getInput().get(0);
+					
+					hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+							  OpOp4.WDIVMM, X, U, V, 2, true, true);
+					HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+					appliedPattern = true;
+					LOG.debug("Applied simplifyWeightedDivMM4 (line "+hi.getBeginLine()+")");	
+				}
+			}
+		}
+		
+		//Pattern 5) (W*(U%*%t(V)))
+		if( !appliedPattern
+			&& hi instanceof BinaryOp && ((BinaryOp)hi).getOp()==OpOp2.MULT
+			&& HopRewriteUtils.isEqualSize(hi.getInput().get(0), hi.getInput().get(1)) //prevent mv
+			&& hi.getDim2() > 1 //not applied for vector-vector mult
+			&& hi.getInput().get(0).getDataType() == DataType.MATRIX 
+			&& hi.getInput().get(0).getDim2() > hi.getInput().get(0).getColsInBlock()
+			&& hi.getInput().get(1) instanceof AggBinaryOp
+			&& (((AggBinaryOp) hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE || hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+			&& HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) ) //BLOCKSIZE CONSTRAINT
+		{
+			Hop W = hi.getInput().get(0); 
+			Hop U = hi.getInput().get(1).getInput().get(0);
+			Hop V = hi.getInput().get(1).getInput().get(1);
+			
+			if( !HopRewriteUtils.isTransposeOperation(V) )
+				V = HopRewriteUtils.createTranspose(V);
+			else 
+				V = V.getInput().get(0);
+				
+			hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE, 
+					  OpOp4.WDIVMM, W, U, V, 0, true, false);
+			HopRewriteUtils.setOutputBlocksizes(hnew, W.getRowsInBlock(), W.getColsInBlock());
+
+			appliedPattern = true;
+			LOG.debug("Applied simplifyWeightedDivMM5 (line "+hi.getBeginLine()+")");	
+		}
+		
+		//relink new hop into original position
+		if( hnew != null ) {
+			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+			HopRewriteUtils.addChildReference(parent, hnew, pos);
+			hi = hnew;
+		}
+		
+		return hi;
+	}
+
+	/**
+	 * Searches for the weighted cross-entropy pattern sum(X * log(U %*% t(V))) and replaces it with a WCEMM quaternary operator.
+	 * @param parent parent hop that references {@code hi} as its input at position {@code pos}
+	 * @param hi candidate root hop; only full sum(...) aggregates (Direction.RowCol, AggOp.SUM) are matched
+	 * @param pos input position of {@code hi} within {@code parent}; used to relink the new hop
+	 * @return the new scalar WCEMM quaternary hop if the pattern matched, otherwise {@code hi} unchanged
+	 * @throws HopsException propagated from the hop construction and rewrite utilities
+	 */
+	private Hop simplifyWeightedCrossEntropy(Hop parent, Hop hi, int pos) 
+		throws HopsException
+	{
+		Hop hnew = null;
+		
+		//Pattern 1) sum( X * log(U %*% t(V)))
+		if( hi instanceof AggUnaryOp && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM     //pattern rooted by sum()
+			&& hi.getInput().get(0) instanceof BinaryOp  //pattern subrooted by binary op
+			&& hi.getInput().get(0).getDim2() > 1   )    //not applied for vector-vector mult
+		{
+			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+			Hop left = bop.getInput().get(0);
+			Hop right = bop.getInput().get(1);
+			
+			if( bop.getOp()==OpOp2.MULT && left.getDataType()==DataType.MATRIX		
+				&& HopRewriteUtils.isEqualSize(left, right)  //prevent mv
+				&& right instanceof UnaryOp	&& ((UnaryOp)right).getOp()==OpOp1.LOG
+				&& right.getInput().get(0) instanceof AggBinaryOp  //ba guarantees matrices
+				&& HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true)) //BLOCKSIZE CONSTRAINT
+			{
+				Hop X = left; 
+				Hop U = right.getInput().get(0).getInput().get(0);
+				Hop V = right.getInput().get(0).getInput().get(1);
+				
+				if( !HopRewriteUtils.isTransposeOperation(V) )
+					V = HopRewriteUtils.createTranspose(V);
+				else 
+					V = V.getInput().get(0);
+					
+				hnew = new QuaternaryOp(hi.getName(), DataType.SCALAR, ValueType.DOUBLE, OpOp4.WCEMM, X, U, V);
+				HopRewriteUtils.setOutputBlocksizes(hnew, X.getRowsInBlock(), X.getColsInBlock()); //NOTE(review): scalar output, yet scalar WSLOSS uses setOutputParametersForScalar - confirm intended
+					
+				LOG.debug("Applied simplifyWeightedCEMM (line "+hi.getBeginLine()+")");					
+			}
+		}
+		
+		//relink new hop into original position
+		if( hnew != null ) {
+			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+			HopRewriteUtils.addChildReference(parent, hnew, pos);
+			hi = hnew;
+		}
+		
+		return hi;
+	}
+	
 	/**
 	 * NOTE: dot-product-sum could be also applied to sum(a*b). However, we 
 	 * restrict ourselfs to sum(a^2) and transitively sum(a*a) since a general mm

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/e52e0c0a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4fc23eb..d2887a2 100644
--- a/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/com/ibm/bi/dml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -29,9 +29,7 @@ import com.ibm.bi.dml.hops.BinaryOp;
 import com.ibm.bi.dml.hops.DataGenOp;
 import com.ibm.bi.dml.hops.Hop;
 import com.ibm.bi.dml.hops.Hop.OpOp1;
-import com.ibm.bi.dml.hops.Hop.OpOp4;
 import com.ibm.bi.dml.hops.IndexingOp;
-import com.ibm.bi.dml.hops.QuaternaryOp;
 import com.ibm.bi.dml.hops.TernaryOp;
 import com.ibm.bi.dml.hops.UnaryOp;
 import com.ibm.bi.dml.hops.Hop.AggOp;
@@ -44,7 +42,6 @@ import com.ibm.bi.dml.hops.LiteralOp;
 import com.ibm.bi.dml.hops.Hop.OpOp2;
 import com.ibm.bi.dml.hops.ParameterizedBuiltinOp;
 import com.ibm.bi.dml.hops.ReorgOp;
-import com.ibm.bi.dml.lops.MapMultChain.ChainType;
 import com.ibm.bi.dml.parser.DataExpression;
 import com.ibm.bi.dml.parser.Statement;
 import com.ibm.bi.dml.parser.Expression.DataType;
@@ -66,7 +63,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 	
 	private static OpOp2[] LOOKUP_VALID_DISTRIBUTIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MINUS}; 
 	private static OpOp2[] LOOKUP_VALID_ASSOCIATIVE_BINARY = new OpOp2[]{OpOp2.PLUS, OpOp2.MULT}; 
-	private static OpOp2[] LOOKUP_VALID_WDIVMM_BINARY = new OpOp2[]{OpOp2.MULT, OpOp2.DIV}; 
 	
 	@Override
 	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) 
@@ -149,10 +145,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = removeUnnecessaryTranspose(hop, hi, i);         //e.g., t(t(X))->X; potentially introduced by diag/trace_MM
 			hi = removeUnnecessaryMinus(hop, hi, i);             //e.g., -(-X)->X; potentially introduced by simplfiy binary or dyn rewrites
 			hi = simplifyGroupedAggregate(hi);          	     //e.g., aggregate(target=X,groups=y,fn="count") -> aggregate(target=y,groups=y,fn="count")
-			hi = simplifyWeightedSquaredLoss(hop, hi, i);        //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true)
-			hi = simplifyWeightedSigmoidMMChains(hop, hi, i);    //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
-			hi = simplifyWeightedDivMM(hop, hi, i);              //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
-			hi = simplifyWeightedCrossEntropy(hop, hi, i);       //e.g., sum(X*log(U%*%t(V))) -> wcemm(X, U, t(V))
 			hi = fuseMinusNzBinaryOperation(hop, hi, i);         //e.g., X-mean*ppred(X,0,!=) -> X -nz mean
 			hi = fuseLogNzBinaryOperation(hop, hi, i);           //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
 			hi = simplifyOuterSeqExpand(hop, hi, i);             //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
@@ -1298,6 +1290,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 					LOG.debug("Applied simplifyGroupedAggregateCount");	
 				}
 			}
+<<<<<<< Upstream, based on branch 'master' of https://git-wip-us.apache.org/repos/asf/incubator-systemml.git
 		}
 		
 		return hi;
@@ -1900,6 +1893,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
 			HopRewriteUtils.addChildReference(parent, hnew, pos);
 			hi = hnew;
+=======
+>>>>>>> 04aa86c Fix rewrite 'fuse sum_sq' (after wsloss rewrite), for kmeans_predict 
 		}
 		
 		return hi;