You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/21 19:54:42 UTC

[2/3] incubator-systemml git commit: [SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap)

[SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/973b8635
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/973b8635
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/973b8635

Branch: refs/heads/master
Commit: 973b863579d7bf82505933d3d67fef4517c53eb3
Parents: b233b59
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Wed Jul 20 22:34:46 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Thu Jul 21 12:54:15 2016 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 29 +++++++++
 .../RewriteAlgebraicSimplificationDynamic.java  | 65 ++++++++++++++++++++
 .../RewriteAlgebraicSimplificationStatic.java   | 41 ------------
 .../misc/RewriteFuseBinaryOpChainTest.java      | 40 ++++++++++--
 .../misc/RewriteFuseBinaryOpChainTest3.R        | 28 +++++++++
 .../misc/RewriteFuseBinaryOpChainTest3.dml      | 27 ++++++++
 6 files changed, 184 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index a5432f1..385a888 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -36,6 +36,7 @@ import org.apache.sysml.hops.Hop.DataOpTypes;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.FileFormatTypes;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.Hop.VisitStatus;
@@ -45,6 +46,7 @@ import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.MemoTable;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.parser.DataExpression;
@@ -644,6 +646,22 @@ public class HopRewriteUtils
 		return datagen;
 	}
 	
+	/**
+	 * Creates a "tmp" TernaryOp with MATRIX/DOUBLE output over the three given inputs.
+	 * @param mleft first input; its row/column block sizes are copied to the new op
+	 * @param smid second input (fuseAxpyBinaryOperationChain passes the scalar factor)
+	 * @param mright third input (fuseAxpyBinaryOperationChain passes the scaled matrix)
+	 * @param op ternary operation type, e.g., OpOp3.PLUS_MULT or OpOp3.MINUS_MULT
+	 * @return the new ternary hop, after refreshSizeInformation() has been called on it
+	 */
+	public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
+		TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright);
+		ternOp.setRowsInBlock(mleft.getRowsInBlock());
+		ternOp.setColsInBlock(mleft.getColsInBlock());
+		ternOp.refreshSizeInformation();
+		return ternOp;
+	}
+	
 	public static void setOutputBlocksizes( Hop hop, long brlen, long bclen )
 	{
 		hop.setRowsInBlock( brlen );
@@ -878,6 +896,17 @@ public class HopRewriteUtils
 	 * @param hop
 	 * @return
 	 */
+	public static boolean isScalarMatrixBinaryMult( Hop hop ) { //true iff hop is a binary MULT of one scalar and one matrix input
+		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT
+			&& ((hop.getInput().get(0).getDataType()==DataType.SCALAR && hop.getInput().get(1).getDataType()==DataType.MATRIX) //s*Y
+			|| (hop.getInput().get(0).getDataType()==DataType.MATRIX && hop.getInput().get(1).getDataType()==DataType.SCALAR)); //Y*s
+	}
+	
+	/**
+	 * 
+	 * @param hop
+	 * @return
+	 */
 	public static boolean isBasic1NSequence(Hop hop)
 	{
 		boolean ret = false;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8205e83..dbde506 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -35,6 +35,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ReOrgOp;
 import org.apache.sysml.hops.HopsException;
@@ -174,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = simplifyWeightedUnaryMM(hop, hi, i);         //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
 			hi = simplifyDotProductSum(hop, hi, i);           //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1 
 			hi = fuseSumSquared(hop, hi, i);                  //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
+			hi = fuseAxpyBinaryOperationChain(hop, hi, i);    //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) 	
 			hi = reorderMinusMatrixMult(hop, hi, i);          //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
 			hi = simplifySumMatrixMult(hop, hi, i);           //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
 			hi = simplifyEmptyBinaryOperation(hop, hi, i);    //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
@@ -2458,6 +2460,69 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		return hi;
 	}
 	
+
+	/**
+	 * Fuses X+s*Y, s*Y+X, and X-s*Y into ternary PLUS_MULT / MINUS_MULT hops.
+	 * @param parent hop from which hi is unlinked and to which the new op is linked
+	 * @param hi candidate hop, expected to be the input of parent at position pos
+	 * @param pos input position of hi in parent, passed to the rewiring utilities
+	 * @return the new ternary hop if a pattern matched, otherwise hi unchanged
+	 *         (this method declares no HopsException)
+	 */
+	private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) 
+	{
+		//patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY		
+		if( hi instanceof BinaryOp 
+			&& (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) )
+		{
+			BinaryOp bop = (BinaryOp) hi;
+			Hop left = bop.getInput().get(0);
+			Hop right = bop.getInput().get(1);
+			Hop ternop = null; //stays null unless one of the patterns below matches
+			
+			//pattern (a) X + s*Y -> X +* sY
+			if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX 
+				&& HopRewriteUtils.isScalarMatrixBinaryMult(right)
+				&& right.getParent().size() == 1 )           //single consumer s*Y
+			{
+				Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+				Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+				ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT);
+				LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
+			}
+			//pattern (b) s*Y + X -> X +* sY
+			else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX 
+				&& HopRewriteUtils.isScalarMatrixBinaryMult(left)
+				&& left.getParent().size() == 1              //single consumer s*Y
+				&& HopRewriteUtils.isEqualSize(left, right)) //correctness matrix-vector
+			{
+				Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+				Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+				ternop = HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT); //operands swapped: X becomes the first input
+				LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");	
+			}
+			//pattern (c) X - s*Y -> X -* sY
+			else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX 
+				&& HopRewriteUtils.isScalarMatrixBinaryMult(right)
+				&& right.getParent().size() == 1 )           //single consumer s*Y
+			{
+				Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1); 
+				Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+				ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT);
+				LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
+			}
+			
+			//rewire parent-child operators if rewrite applied
+			if( ternop != null ) { 
+				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+				HopRewriteUtils.addChildReference(parent, ternop, pos);
+				hi = ternop; //returned below, so the caller continues with the fused op
+			}
+		}
+		
+		return hi;
+	}
+	
 	/**
 	 * 
 	 * @param parent

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 9ef2c05..ae9c073 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -162,7 +162,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = fuseLogNzBinaryOperation(hop, hi, i);           //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
 			hi = simplifyOuterSeqExpand(hop, hi, i);             //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
 			hi = simplifyTableSeqExpand(hop, hi, i);             //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
-			hi = fuseBinaryOperationChain(hop, hi, i);			 //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y) 	
 			//hi = removeUnecessaryPPred(hop, hi, i);            //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
 
 			//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -1906,44 +1905,4 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 		
 		return hi;
 	}
-
-	/**
-	 * 
-	 * @param parent
-	 * @param hi
-	 * @param pos
-	 * @return
-	 * @throws HopsException
-	 */
-	private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
-		//pattern: X + lamda*Y -> X +* lambda Y		
-		if( hi instanceof BinaryOp 
-			&& (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) 
-			&& hi.getInput().get(0).getDataType()==DataType.MATRIX 
-			&& hi.getInput().get(1) instanceof BinaryOp 
-			&& ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT )
-		{
-			//Check that the inner binary Op is a product of Scalar times Matrix or viceversa
-			Hop innerBinaryOp =  hi.getInput().get(1);
-			if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX) 
-					|| (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
-			{
-				//check which operand is the Scalar and which is the matrix
-				Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1); 
-				Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
-
-				OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
-				TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
-				HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0));
-				
-				HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
-				HopRewriteUtils.addChildReference(parent, ternOp, pos);
-				
-				LOG.debug("Applied fuseBinaryOperationChain. (line " +hi.getBeginLine()+")");
-				return ternOp;
-			}
-		}
-		
-		return hi;
-	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 890a3b2..ff85ebc 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -40,8 +40,9 @@ import org.apache.sysml.utils.Statistics;
  */
 public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase 
 {
-	private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1";
-	private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2";
+	private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y)
+	private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y) 
+	private static final String TEST_NAME3 = "RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X)
 
 	private static final String TEST_DIR = "functions/misc/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
@@ -53,6 +54,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 		TestUtils.clearAssertionInformation();
 		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
 		addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+		addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
 	}
 	
 	@Test
@@ -60,7 +62,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 		testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
 	}
 	
-	
 	@Test
 	public void testFuseBinaryPlusRewriteCP() {
 		testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
@@ -77,6 +78,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 	}
 	
 	@Test
+	public void testFuseBinaryPlus2NoRewriteCP() { //s*Y+X script on CP; boolean false = no rewrite, per the method name
+		testFuseBinaryChain( TEST_NAME3, false, ExecType.CP );
+	}
+	
+	@Test
+	public void testFuseBinaryPlus2RewriteCP() { //s*Y+X script on CP with rewrite; "+*" opcode is asserted for TEST_NAME3
+		testFuseBinaryChain( TEST_NAME3, true, ExecType.CP );
+	}
+	
+	@Test
 	public void testFuseBinaryPlusNoRewriteSP() {
 		testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
 	}
@@ -97,6 +108,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 	}
 	
 	@Test
+	public void testFuseBinaryPlus2NoRewriteSP() { //s*Y+X script on Spark; boolean false = no rewrite, per the method name
+		testFuseBinaryChain( TEST_NAME3, false, ExecType.SPARK );
+	}
+	
+	@Test
+	public void testFuseBinaryPlus2RewriteSP() { //s*Y+X script on Spark with rewrite; SP-prefixed "+*" opcode is asserted
+		testFuseBinaryChain( TEST_NAME3, true, ExecType.SPARK );
+	}
+	
+	@Test
 	public void testFuseBinaryPlusNoRewriteMR() {
 		testFuseBinaryChain( TEST_NAME1, false, ExecType.MR );
 	}
@@ -116,6 +137,15 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 		testFuseBinaryChain( TEST_NAME2, true, ExecType.MR );
 	}
 	
+	@Test
+	public void testFuseBinaryPlus2NoRewriteMR() { //s*Y+X script on MR; boolean false = no rewrite, per the method name
+		testFuseBinaryChain( TEST_NAME3, false, ExecType.MR );
+	}
+	
+	@Test
+	public void testFuseBinaryPlus2RewriteMR() { //s*Y+X script on MR with rewrite; the opcode assertion is skipped for MR
+		testFuseBinaryChain( TEST_NAME3, true, ExecType.MR );
+	}
 	
 	/**
 	 * 
@@ -162,8 +192,8 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
 			//check for applies rewrites
 			if( rewrites && instType!=ExecType.MR  ) {
 				String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX  : "";
-				Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes()
-						.contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));
+				String opcode = (testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : prefix+"-*";
+				Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode));
 			}
 		}
 		finally

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
new file mode 100644
index 0000000..5ae1642
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE) # only args[2] is used below, as the output path prefix
+options(digits=22)
+library("Matrix") # provides writeMM and the CsparseMatrix class
+
+X=matrix(1,10,10) # 10x10 matrix of ones
+Y=matrix(1,10,10) # 10x10 matrix of ones
+lamda=7 # scalar factor s
+S=lamda*Y+X # s*Y+X, i.e., the scalar-matrix product is the left operand
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); # result goes to file <args[2]>S

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
new file mode 100644
index 0000000..af84884
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,rows=10,cols=10) # 10x10 matrix of ones
+Y=matrix(1,rows=10,cols=10) # 10x10 matrix of ones
+if(1==1){} # NOTE(review): empty branch, presumably there to separate the statements above from those below - confirm
+lamda=7 # scalar factor s
+S=lamda*Y+X # s*Y+X, i.e., the scalar-matrix product is the left operand
+write(S,$1) # output location is taken from the first script argument
\ No newline at end of file