You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/07/21 19:54:42 UTC
[2/3] incubator-systemml git commit: [SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap)
[SYSTEMML-766] Improved 'fuse axpy' rewrite (more patterns, no overlap)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/973b8635
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/973b8635
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/973b8635
Branch: refs/heads/master
Commit: 973b863579d7bf82505933d3d67fef4517c53eb3
Parents: b233b59
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Wed Jul 20 22:34:46 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Thu Jul 21 12:54:15 2016 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 29 +++++++++
.../RewriteAlgebraicSimplificationDynamic.java | 65 ++++++++++++++++++++
.../RewriteAlgebraicSimplificationStatic.java | 41 ------------
.../misc/RewriteFuseBinaryOpChainTest.java | 40 ++++++++++--
.../misc/RewriteFuseBinaryOpChainTest3.R | 28 +++++++++
.../misc/RewriteFuseBinaryOpChainTest3.dml | 27 ++++++++
6 files changed, 184 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index a5432f1..385a888 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -36,6 +36,7 @@ import org.apache.sysml.hops.Hop.DataOpTypes;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.FileFormatTypes;
import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.Hop.VisitStatus;
@@ -45,6 +46,7 @@ import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.ReorgOp;
+import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.parser.DataExpression;
@@ -644,6 +646,22 @@ public class HopRewriteUtils
return datagen;
}
+ /**
+ * Creates a TernaryOp named "tmp" of type MATRIX/DOUBLE over (mleft, smid, mright), copying block sizes from mleft and refreshing its size information.
+ * @param mleft first input (a matrix at the axpy rewrite call sites); also the source of the output block sizes
+ * @param smid second input (the scalar factor s at the axpy rewrite call sites)
+ * @param mright third input (the matrix scaled by smid at the axpy rewrite call sites)
+ * @param op ternary operation type, e.g., OpOp3.PLUS_MULT or OpOp3.MINUS_MULT
+ * @return the new TernaryOp (this method does not attach it to any parent)
+ */
+ public static TernaryOp createTernaryOp(Hop mleft, Hop smid, Hop mright, OpOp3 op) {
+ TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, mleft, smid, mright);
+ ternOp.setRowsInBlock(mleft.getRowsInBlock());
+ ternOp.setColsInBlock(mleft.getColsInBlock());
+ ternOp.refreshSizeInformation();
+ return ternOp;
+ }
+
public static void setOutputBlocksizes( Hop hop, long brlen, long bclen )
{
hop.setRowsInBlock( brlen );
@@ -878,6 +896,17 @@ public class HopRewriteUtils
* @param hop
* @return
*/
+ public static boolean isScalarMatrixBinaryMult( Hop hop ) { //true iff hop is a MULT BinaryOp with one SCALAR and one MATRIX input, in either order
+ return hop instanceof BinaryOp && ((BinaryOp)hop).getOp()==OpOp2.MULT
+ && ((hop.getInput().get(0).getDataType()==DataType.SCALAR && hop.getInput().get(1).getDataType()==DataType.MATRIX)
+ || (hop.getInput().get(0).getDataType()==DataType.MATRIX && hop.getInput().get(1).getDataType()==DataType.SCALAR));
+ }
+
+ /**
+ *
+ * @param hop
+ * @return
+ */
public static boolean isBasic1NSequence(Hop hop)
{
boolean ret = false;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8205e83..dbde506 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -35,6 +35,7 @@ import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.OpOp4;
import org.apache.sysml.hops.Hop.ReOrgOp;
import org.apache.sysml.hops.HopsException;
@@ -174,6 +175,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyWeightedUnaryMM(hop, hi, i); //e.g., X*exp(U%*%t(V)) -> wumm(X, U, t(V), exp)
hi = simplifyDotProductSum(hop, hi, i); //e.g., sum(v^2) -> t(v)%*%v if ncol(v)==1
hi = fuseSumSquared(hop, hi, i); //e.g., sum(X^2) -> sumSq(X), if ncol(X)>1
+ hi = fuseAxpyBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
@@ -2458,6 +2460,69 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+
+ /**
+ * Fuses an axpy-style binary chain into one TernaryOp: X + s*Y and s*Y + X become PLUS_MULT (+*), X - s*Y becomes MINUS_MULT (-*).
+ * @param parent parent operator whose input at position pos is hi; rewired to the new TernaryOp if a pattern applies
+ * @param hi candidate operator; only PLUS/MINUS BinaryOps are considered, anything else is returned unchanged
+ * @param pos position of hi among the inputs of parent
+ * @return the new TernaryOp if a pattern applied, otherwise hi
+ * All patterns require s*Y to have a single consumer; pattern (b), which moves X to the first input, also requires HopRewriteUtils.isEqualSize(s*Y, X).
+ */
+ private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos)
+ {
+ //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY
+ if( hi instanceof BinaryOp
+ && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) )
+ {
+ BinaryOp bop = (BinaryOp) hi;
+ Hop left = bop.getInput().get(0);
+ Hop right = bop.getInput().get(1);
+ Hop ternop = null; //stays null if no pattern matches
+
+ //pattern (a) X + s*Y -> X +* sY
+ if( bop.getOp() == OpOp2.PLUS && left.getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isScalarMatrixBinaryMult(right)
+ && right.getParent().size() == 1 ) //single consumer s*Y
+ {
+ Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.PLUS_MULT);
+ LOG.debug("Applied fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
+ }
+ //pattern (b) s*Y + X -> X +* sY
+ else if( bop.getOp() == OpOp2.PLUS && right.getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isScalarMatrixBinaryMult(left)
+ && left.getParent().size() == 1 //single consumer s*Y
+ && HopRewriteUtils.isEqualSize(left, right)) //correctness matrix-vector
+ {
+ Hop smid = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = left.getInput().get( (left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ ternop = HopRewriteUtils.createTernaryOp(right, smid, mright, OpOp3.PLUS_MULT);
+ LOG.debug("Applied fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");
+ }
+ //pattern (c) X - s*Y -> X -* sY
+ else if( bop.getOp() == OpOp2.MINUS && left.getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isScalarMatrixBinaryMult(right)
+ && right.getParent().size() == 1 ) //single consumer s*Y
+ {
+ Hop smid = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = right.getInput().get( (right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ ternop = HopRewriteUtils.createTernaryOp(left, smid, mright, OpOp3.MINUS_MULT);
+ LOG.debug("Applied fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
+ }
+
+ //rewire parent-child operators if rewrite applied
+ if( ternop != null ) {
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, ternop, pos);
+ hi = ternop;
+ }
+ }
+
+ return hi;
+ }
+
/**
*
* @param parent
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 9ef2c05..ae9c073 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -162,7 +162,6 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
hi = fuseLogNzBinaryOperation(hop, hi, i); //e.g., ppred(X,0,"!=")*log(X,0.5) -> log_nz(X,0.5)
hi = simplifyOuterSeqExpand(hop, hi, i); //e.g., outer(v, seq(1,m), "==") -> rexpand(v, max=m, dir=row, ignore=true, cast=false)
hi = simplifyTableSeqExpand(hop, hi, i); //e.g., table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
- hi = fuseBinaryOperationChain(hop, hi, i); //e.g., (X+s*Y) -> (X+*s Y), (X-s*Y) -> (X-*s Y)
//hi = removeUnecessaryPPred(hop, hi, i); //e.g., ppred(X,X,"==")->matrix(1,rows=nrow(X),cols=ncol(X))
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -1906,44 +1905,4 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
return hi;
}
-
- /**
- *
- * @param parent
- * @param hi
- * @param pos
- * @return
- * @throws HopsException
- */
- private Hop fuseBinaryOperationChain(Hop parent, Hop hi, int pos) {
- //pattern: X + lamda*Y -> X +* lambda Y
- if( hi instanceof BinaryOp
- && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS)
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1) instanceof BinaryOp
- && ((BinaryOp)hi.getInput().get(1)).getOp()==OpOp2.MULT )
- {
- //Check that the inner binary Op is a product of Scalar times Matrix or viceversa
- Hop innerBinaryOp = hi.getInput().get(1);
- if ( (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR && innerBinaryOp.getInput().get(1).getDataType()==DataType.MATRIX)
- || (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX && innerBinaryOp.getInput().get(1).getDataType()==DataType.SCALAR))
- {
- //check which operand is the Scalar and which is the matrix
- Hop lamda = (innerBinaryOp.getInput().get(0).getDataType()==DataType.SCALAR) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
- Hop matrix = (innerBinaryOp.getInput().get(0).getDataType()==DataType.MATRIX) ? innerBinaryOp.getInput().get(0) : innerBinaryOp.getInput().get(1);
-
- OpOp3 op = (((BinaryOp)hi).getOp()==OpOp2.PLUS) ? OpOp3.PLUS_MULT : OpOp3.MINUS_MULT;
- TernaryOp ternOp = new TernaryOp("tmp", DataType.MATRIX, ValueType.DOUBLE, op, hi.getInput().get(0), lamda, matrix);
- HopRewriteUtils.refreshOutputParameters(ternOp, hi.getInput().get(0));
-
- HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
- HopRewriteUtils.addChildReference(parent, ternOp, pos);
-
- LOG.debug("Applied fuseBinaryOperationChain. (line " +hi.getBeginLine()+")");
- return ternOp;
- }
- }
-
- return hi;
- }
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
index 890a3b2..ff85ebc 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java
@@ -40,8 +40,9 @@ import org.apache.sysml.utils.Statistics;
*/
public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
{
- private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1";
- private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2";
+ private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y)
+ private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y)
+ private static final String TEST_NAME3 = "RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X)
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/";
@@ -53,6 +54,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
TestUtils.clearAssertionInformation();
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
}
@Test
@@ -60,7 +62,6 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
testFuseBinaryChain( TEST_NAME1, false, ExecType.CP );
}
-
@Test
public void testFuseBinaryPlusRewriteCP() {
testFuseBinaryChain( TEST_NAME1, true, ExecType.CP);
@@ -77,6 +78,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
}
@Test
+ public void testFuseBinaryPlus2NoRewriteCP() {
+ testFuseBinaryChain( TEST_NAME3, false, ExecType.CP ); //s*Y+X script, rewrites disabled
+ }
+
+ @Test
+ public void testFuseBinaryPlus2RewriteCP() {
+ testFuseBinaryChain( TEST_NAME3, true, ExecType.CP ); //s*Y+X script, rewrites enabled (checks for the +* opcode)
+ }
+
+ @Test
public void testFuseBinaryPlusNoRewriteSP() {
testFuseBinaryChain( TEST_NAME1, false, ExecType.SPARK );
}
@@ -97,6 +108,16 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
}
@Test
+ public void testFuseBinaryPlus2NoRewriteSP() {
+ testFuseBinaryChain( TEST_NAME3, false, ExecType.SPARK ); //s*Y+X script, rewrites disabled
+ }
+
+ @Test
+ public void testFuseBinaryPlus2RewriteSP() {
+ testFuseBinaryChain( TEST_NAME3, true, ExecType.SPARK ); //s*Y+X script, rewrites enabled (checks for the SP-prefixed +* opcode)
+ }
+
+ @Test
public void testFuseBinaryPlusNoRewriteMR() {
testFuseBinaryChain( TEST_NAME1, false, ExecType.MR );
}
@@ -116,6 +137,15 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
testFuseBinaryChain( TEST_NAME2, true, ExecType.MR );
}
+ @Test
+ public void testFuseBinaryPlus2NoRewriteMR() {
+ testFuseBinaryChain( TEST_NAME3, false, ExecType.MR ); //s*Y+X script, rewrites disabled
+ }
+
+ @Test
+ public void testFuseBinaryPlus2RewriteMR() {
+ testFuseBinaryChain( TEST_NAME3, true, ExecType.MR ); //s*Y+X script, rewrites enabled (opcode check is skipped for MR)
+ }
/**
*
@@ -162,8 +192,8 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
//check for applies rewrites
if( rewrites && instType!=ExecType.MR ) {
String prefix = (instType==ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "";
- Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes()
- .contains(testname.equals(TEST_NAME1) ? prefix+"+*" : prefix+"-*" ));
+ String opcode = (testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : prefix+"-*";
+ Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode));
}
}
finally
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
new file mode 100644
index 0000000..5ae1642
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.R
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE) # args[2] is used below as the output path prefix (presumably supplied by the test harness)
+options(digits=22)
+library("Matrix") # provides writeMM and the CsparseMatrix class
+
+X=matrix(1,10,10)
+Y=matrix(1,10,10)
+lamda=7
+S=lamda*Y+X # reference result for pattern (b): s*Y + X
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/973b8635/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
new file mode 100644
index 0000000..af84884
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest3.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X=matrix(1,rows=10,cols=10)
+Y=matrix(1,rows=10,cols=10)
+if(1==1){} # NOTE(review): presumably forces a statement-block boundary so the expression below is compiled on its own - confirm
+lamda=7
+S=lamda*Y+X # pattern (b) s*Y + X; expected to become X +* s Y when rewrites are enabled
+write(S,$1) # $1: output path argument
\ No newline at end of file