You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/17 03:35:25 UTC

[3/3] incubator-systemml git commit: [SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info)

[SYSTEMML-766] Fix rewrite 'fuse binary axpy' (missing blocksize info)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2b7fdb2b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2b7fdb2b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2b7fdb2b

Branch: refs/heads/master
Commit: 2b7fdb2b36df0aedc6b92bf138e7f0074eed7762
Parents: cbc4509
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jul 16 20:32:28 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jul 16 20:32:28 2016 -0700

----------------------------------------------------------------------
 .../rewrite/RewriteAlgebraicSimplificationStatic.java  | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2b7fdb2b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 43d5791..816b55a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -165,7 +165,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = fuseLogNzBinaryOperation(hop, hi, i);           //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
 			hi = simplifyOuterSeqExpand(hop, hi, i);             //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
 			hi = simplifyTableSeqExpand(hop, hi, i);             //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
-			hi = fuseBinaryOperationChain(hop, hi, i);			 //e.g., X + lamda*Y -> X +* lambda Y	
+			hi = fuseBinaryOperationChain(hop, hi, i);			 //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) 	
 			//hi = removeUnecessaryPPred(hop, hi, i);            //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
 			//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -1922,11 +1922,11 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		//pattern: X + lamda*Y -> X +* lambda Y		
 		if( hi instanceof BinaryOp
 				&& (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) 
-				&& ((BinaryOp)hi).getInput().get(0).getDataType()==DataType.MATRIX && ((BinaryOp)hi).getInput().get(1) instanceof BinaryOp 
+				&& hi.getInput().get(0).getDataType()==DataType.MATRIX && hi.getInput().get(1) instanceof BinaryOp 
 				&& (DMLScript.rtplatform == RUNTIME_PLATFORM.SINGLE_NODE || OptimizerUtils.isSparkExecutionMode()) )
 		{
 			//Check that the inner binary Op is a product of Scalar times Matrix or viceversa
-			Hop innerBinaryOp =  ((BinaryOp)hi).getInput().get(1);
+			Hop innerBinaryOp =  hi.getInput().get(1);
 			if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) 
 					|| (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
 			{
@@ -1934,8 +1934,9 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); 
 				Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
 
-				OpOp3 operator = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
-				TernaryOp ternOp=new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, operator, ((BinaryOp)hi).getInput().get(0), lamda, matrix);
+				OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
+				TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
+				HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0));
 				
 				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
 				HopRewriteUtils.addChildReference(parent, ternOp, pos);
@@ -1944,7 +1945,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				return ternOp;
 			}
 		}
+		
 		return hi;
-	
 	}
 }