You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/08 20:07:59 UTC
[4/4] incubator-systemml git commit: New simplification rewrite
'pushdown sum on additive binary', for ppca
New simplification rewrite 'pushdown sum on additive binary', for ppca
Specifically, we now rewrite sum(A+B) -> sum(A)+sum(B) and sum(A-B) ->
sum(A)-sum(B) if dims(A)==dims(B) and dt(A)==dt(B)==MATRIX. This
avoids an unnecessary intermediate, reduces the data scans from three
reads and one write to two reads, and replaces a binary operation followed
by a unary aggregate with pure unary aggregates that are easier to
parallelize. Down the road, we can generalize this to matrix-vector and
matrix-scalar operations too.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19af3f9b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19af3f9b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19af3f9b
Branch: refs/heads/master
Commit: 19af3f9be3736853ff0ccae4e2b074a4b5905c03
Parents: 83a5b42
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jan 8 11:07:18 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Jan 8 11:07:18 2016 -0800
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 30 ++++
.../RewriteAlgebraicSimplificationDynamic.java | 51 ++++++
.../aggregate/PushdownSumBinaryTest.java | 163 +++++++++++++++++++
.../scripts/functions/aggregate/PushdownSum1.R | 34 ++++
.../functions/aggregate/PushdownSum1.dml | 25 +++
.../scripts/functions/aggregate/PushdownSum2.R | 34 ++++
.../functions/aggregate/PushdownSum2.dml | 25 +++
7 files changed, 362 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 95ddf0f..891c0b1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
@@ -32,6 +33,7 @@ import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.FileFormatTypes;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -551,6 +553,34 @@ public class HopRewriteUtils
/**
*
+ * Creates a full aggregate sum(input) over all rows and columns,
+ * i.e., an AggUnaryOp with AggOp.SUM and Direction.RowCol.
+ *
+ * @param input the hop to aggregate
+ * @return a new AggUnaryOp computing the full sum of {@code input}
+ */
+ public static AggUnaryOp createSum( Hop input ) {
+ return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol);
+ }
+
+ /**
+ * Creates a new aggregate unary operator over the given input, inherits the
+ * input's block sizes, and refreshes its size information.
+ *
+ * @param input the hop to aggregate; its name and value type are reused for the new hop
+ * @param op the aggregation function, e.g., AggOp.SUM
+ * @param dir the aggregation direction, e.g., Direction.RowCol for a full aggregate
+ * @return the new AggUnaryOp constructed over {@code input}
+ */
+ public static AggUnaryOp createAggUnaryOp( Hop input, AggOp op, Direction dir )
+ {
+ //full (RowCol) aggregates produce a scalar; other directions keep the input's data type
+ DataType dt = (dir==Direction.RowCol) ? DataType.SCALAR : input.getDataType();
+
+ AggUnaryOp auop = new AggUnaryOp(input.getName(), dt, input.getValueType(), op, dir, input);
+ //inherit blocking from the input before recomputing output size information
+ auop.setRowsInBlock(input.getRowsInBlock());
+ auop.setColsInBlock(input.getColsInBlock());
+ auop.refreshSizeInformation();
+
+ return auop;
+ }
+
+ /**
+ *
* @param left
* @param right
* @return
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 7c4a67a..31c394b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -166,6 +166,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = simplifyDiagMatrixMult(hop, hi, i); //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
hi = simplifySumDiagToTrace(hi); //e.g., sum(diag(X)) -> trace(X); if col vector
hi = pushdownBinaryOperationOnDiag(hop, hi, i); //e.g., diag(X)*7 -> diag(X*7); if col vector
+ hi = pushdownSumOnAdditiveBinary(hop, hi, i); //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
hi = simplifyWeightedSquaredLoss(hop, hi, i); //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true),
hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
hi = simplifyWeightedDivMM(hop, hi, i); //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
@@ -1349,6 +1350,56 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
+ /**
+ * Pushes a full sum down into an additive binary operation over two
+ * equally-sized matrices, so that the intermediate A+B (or A-B) is not needed.
+ * patterns: sum(A+B)->sum(A)+sum(B); sum(A-B)->sum(A)-sum(B)
+ * Only applies if both inputs are MATRIX with equal dimensions and the binary
+ * operation has no consumer other than the sum.
+ *
+ * @param parent the parent of {@code hi}; rewired to the new subdag if the rewrite applies
+ * @param hi the candidate hop (a full sum over a binary operation, if the pattern matches)
+ * @param pos the input position of {@code hi} within {@code parent}
+ * @return the new binary hop sum(A) op sum(B) if the rewrite applied, otherwise {@code hi}
+ */
+ private Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
+ {
+ //all patterns headed by full sum over binary operation
+ if( hi instanceof AggUnaryOp //full sum root over binaryop
+ && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && ((AggUnaryOp)hi).getOp() == AggOp.SUM
+ && hi.getInput().get(0) instanceof BinaryOp
+ && hi.getInput().get(0).getParent().size()==1 ) //single parent: binary op consumed only by this sum
+ {
+ BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+ Hop left = bop.getInput().get(0);
+ Hop right = bop.getInput().get(1);
+
+ //restricted to matrix-matrix with equal dims (no matrix-vector/matrix-scalar yet)
+ if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
+ && left.getDataType() == DataType.MATRIX
+ && right.getDataType() == DataType.MATRIX )
+ {
+ //keep the binary opcode only for PLUS/MINUS, otherwise null (no rewrite)
+ OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
+ || bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
+ ? bop.getOp() : null;
+
+ if( applyOp != null ) {
+ //create new subdag sum(A) bop sum(B)
+ AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
+ AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
+ BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
+
+ //rewire new subdag: detach hi from parent, drop the old sum->bop and
+ //bop->{A,B} edges, and attach the new subdag at the same input position
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.removeAllChildReferences(hi);
+ HopRewriteUtils.removeAllChildReferences(bop);
+ HopRewriteUtils.addChildReference(parent, newBin, pos);
+
+ hi = newBin;
+
+ LOG.debug("Applied pushdownSumOnAdditiveBinary.");
+ }
+ }
+ }
+
+ return hi;
+ }
/**
* Searches for weighted squared loss expressions and replaces them with a quaternary operator.
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
new file mode 100644
index 0000000..1b87231
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Integration tests for the rewrite sum(A+B) -> sum(A)+sum(B) and
+ * sum(A-B) -> sum(A)-sum(B). Each test compares the DML result against R
+ * and checks the heavy-hitter opcodes to verify whether the rewrite was
+ * applied (equal dims) or not (B is a column vector).
+ */
+public class PushdownSumBinaryTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "PushdownSum1"; //sum(A+B)
+ private final static String TEST_NAME2 = "PushdownSum2"; //sum(A-B)
+
+ private final static String TEST_DIR = "functions/aggregate/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + PushdownSumBinaryTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-10;
+
+ private final static int rows = 1765;
+ private final static int cols = 19;
+ private final static double sparsity = 0.1;
+
+
+ @Override
+ public void setUp()
+ {
+ //both tests produce a single output matrix C
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"C"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"C"}));
+ TestUtils.clearAssertionInformation();
+
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @BeforeClass
+ public static void init()
+ {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp()
+ {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+ }
+
+ @Test
+ public void testPushDownSumPlusSP() {
+ runPushdownSumOnBinaryTest(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testPushDownSumMinusSP() {
+ runPushdownSumOnBinaryTest(TEST_NAME2, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testPushDownSumPlusNoRewriteSP() {
+ runPushdownSumOnBinaryTest(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testPushDownSumMinusNoRewriteSP() {
+ runPushdownSumOnBinaryTest(TEST_NAME2, false, ExecType.SPARK);
+ }
+
+ /**
+ * Runs the given test script on random inputs A (rows x cols) and B,
+ * compares the DML result C against the R result, and checks the
+ * executed opcodes to verify whether the rewrite was applied.
+ *
+ * @param testname TEST_NAME1 (sum(A+B)) or TEST_NAME2 (sum(A-B))
+ * @param equiDims if true, B has the same dims as A and the rewrite is expected to apply;
+ * if false, B is a column vector (rows x 1) and the rewrite is expected not to apply
+ * @param instType execution type used to select the runtime platform
+ */
+ private void runPushdownSumOnBinaryTest( String testname, boolean equiDims, ExecType instType)
+ {
+ //set the runtime platform according to instType (MR, SPARK, otherwise HYBRID)
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( instType ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try
+ {
+ //determine script and function name
+ String TEST_NAME = testname;
+ String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + "_" + String.valueOf(equiDims) + "/" : "";
+
+ TestConfiguration config = getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config, TEST_CACHE_DIR);
+
+ // This is for running the junit test the new way, i.e., construct the arguments directly
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain","-stats","-args", input("A"), input("B"), output("C") };
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+ //generate actual dataset (B is a column vector if !equiDims)
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, sparsity, 7);
+ writeInputMatrixWithMTD("A", A, true);
+ double[][] B = getRandomMatrix(rows, equiDims ? cols : 1, -1, 1, sparsity, 73);
+ writeInputMatrixWithMTD("B", B, true);
+
+ //run tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare output matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //rewrite applied: expect the plain +/- opcode (scalar op over the two sums);
+ //not applied: expect the Spark "map"+op instruction for the matrix-vector operation
+ String lopcode = TEST_NAME.equals(TEST_NAME1) ? "+" : "-";
+ String opcode = equiDims ? lopcode : Instruction.SP_INST_PREFIX+"map"+lopcode;
+ Assert.assertTrue("Non-applied rewrite", Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+ }
+ finally
+ {
+ //restore global configuration
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum1.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum1.R b/src/test/scripts/functions/aggregate/PushdownSum1.R
new file mode 100644
index 0000000..4eb5c8b
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum1.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# R reference for PushdownSum1.dml: C = sum(A+B) as a 1x1 matrix.
+# args[1]: input directory containing A.mtx and B.mtx
+# args[2]: output directory for the expected result C
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B <- as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+# replicate a column-vector B across all columns of A so that A+B is cell-wise
+# (presumably mirroring DML's matrix-vector broadcasting for the unequal-dims case)
+if( ncol(B) == 1 ) {
+ B <- B %*% matrix(1,1,ncol(A))
+}
+
+C = as.matrix(sum(A+B));
+
+writeMM(as(C, "CsparseMatrix"), paste(args[2], "C", sep=""));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum1.dml b/src/test/scripts/functions/aggregate/PushdownSum1.dml
new file mode 100644
index 0000000..e49db15
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum1.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# $1, $2: input matrices A and B; $3: output path for C
+A = read($1);
+B = read($2);
+# full sum over A+B; if dims(A)==dims(B), this is a candidate for sum(A)+sum(B)
+C = as.matrix(sum(A+B))
+write(C, $3);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum2.R b/src/test/scripts/functions/aggregate/PushdownSum2.R
new file mode 100644
index 0000000..08986ff
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum2.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# R reference for PushdownSum2.dml: C = sum(A-B) as a 1x1 matrix.
+# args[1]: input directory containing A.mtx and B.mtx
+# args[2]: output directory for the expected result C
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B <- as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
+# replicate a column-vector B across all columns of A so that A-B is cell-wise
+# (presumably mirroring DML's matrix-vector broadcasting for the unequal-dims case)
+if( ncol(B) == 1 ) {
+ B <- B %*% matrix(1,1,ncol(A))
+}
+
+C = as.matrix(sum(A-B));
+
+writeMM(as(C, "CsparseMatrix"), paste(args[2], "C", sep=""));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum2.dml b/src/test/scripts/functions/aggregate/PushdownSum2.dml
new file mode 100644
index 0000000..eec34e7
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum2.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# $1, $2: input matrices A and B; $3: output path for C
+A = read($1);
+B = read($2);
+# full sum over A-B; if dims(A)==dims(B), this is a candidate for sum(A)-sum(B)
+C = as.matrix(sum(A-B))
+write(C, $3);