You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/08 20:07:59 UTC

[4/4] incubator-systemml git commit: New simplification rewrite 'pushdown sum on additive binary', for ppca

New simplification rewrite 'pushdown sum on additive binary', for ppca

For example, we now rewrite sum(A+B) -> sum(A)+sum(B) and sum(A-B) ->
sum(A)-sum(B) if dims(A)==dims(B) and dt(A)==dt(B)==MATRIX. This
avoids materializing the intermediate A+B (or A-B), reduces the data
accesses from three reads and one write to two reads, and turns a binary
operation followed by an aggregation into pure unary aggregations that
are easier to parallelize. In the future, we can generalize this to
matrix-vector and matrix-scalar operations as well.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/19af3f9b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/19af3f9b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/19af3f9b

Branch: refs/heads/master
Commit: 19af3f9be3736853ff0ccae4e2b074a4b5905c03
Parents: 83a5b42
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jan 8 11:07:18 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Jan 8 11:07:18 2016 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  30 ++++
 .../RewriteAlgebraicSimplificationDynamic.java  |  51 ++++++
 .../aggregate/PushdownSumBinaryTest.java        | 163 +++++++++++++++++++
 .../scripts/functions/aggregate/PushdownSum1.R  |  34 ++++
 .../functions/aggregate/PushdownSum1.dml        |  25 +++
 .../scripts/functions/aggregate/PushdownSum2.R  |  34 ++++
 .../functions/aggregate/PushdownSum2.dml        |  25 +++
 7 files changed, 362 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 95ddf0f..891c0b1 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.AggBinaryOp;
+import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.DataOp;
 import org.apache.sysml.hops.Hop;
@@ -32,6 +33,7 @@ import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.FileFormatTypes;
 import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -551,6 +553,34 @@ public class HopRewriteUtils
 	
 	/**
 	 * 
+	 * @param input hop to aggregate; the result is the full (RowCol) sum over all its cells
+	 * @return new scalar AggUnaryOp computing sum(input), created via createAggUnaryOp
+	 */
+	public static AggUnaryOp createSum( Hop input ) {
+		return createAggUnaryOp(input, AggOp.SUM, Direction.RowCol);
+	}
+	
+	/**
+	 * Creates a new aggregate unary hop op(input) along the given direction, with input as its input.
+	 * @param input input hop; its name, value type and rows/cols-in-block are copied to the new hop
+	 * @param op aggregation operation, e.g., AggOp.SUM
+	 * @param dir aggregation direction; RowCol yields a SCALAR, any other direction keeps input's data type
+	 * @return new aggregate unary hop whose size information has been refreshed
+	 */
+	public static AggUnaryOp createAggUnaryOp( Hop input, AggOp op, Direction dir )
+	{
+		DataType dt = (dir==Direction.RowCol) ? DataType.SCALAR : input.getDataType(); //full aggregates yield scalars
+		
+		AggUnaryOp auop = new AggUnaryOp(input.getName(), dt, input.getValueType(), op, dir, input);
+		auop.setRowsInBlock(input.getRowsInBlock());
+		auop.setColsInBlock(input.getColsInBlock());
+		auop.refreshSizeInformation();
+		
+		return auop;
+	}
+	
+	/**
+	 * 
 	 * @param left
 	 * @param right
 	 * @return

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 7c4a67a..31c394b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -166,6 +166,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			hi = simplifyDiagMatrixMult(hop, hi, i);          //e.g., diag(X%*%Y)->rowSums(X*t(Y)); if col vector
 			hi = simplifySumDiagToTrace(hi);                  //e.g., sum(diag(X)) -> trace(X); if col vector
 			hi = pushdownBinaryOperationOnDiag(hop, hi, i);   //e.g., diag(X)*7 -> diag(X*7); if col vector
+			hi = pushdownSumOnAdditiveBinary(hop, hi, i);     //e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
 			hi = simplifyWeightedSquaredLoss(hop, hi, i);     //e.g., sum(W * (X - U %*% t(V)) ^ 2) -> wsl(X, U, t(V), W, true), 
 			hi = simplifyWeightedSigmoidMMChains(hop, hi, i); //e.g., W * sigmoid(Y%*%t(X)) -> wsigmoid(W, Y, t(X), type)
 			hi = simplifyWeightedDivMM(hop, hi, i);           //e.g., t(U) %*% (X/(U%*%t(V))) -> wdivmm(X, U, t(V), left)
@@ -1349,6 +1350,56 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		return hi;
 	}
 	
+	/**
+	 * Pushes a full sum into an additive binary op; patterns: sum(A+B)->sum(A)+sum(B); sum(A-B)->sum(A)-sum(B)
+	 * Applies only if the binary op's single parent is hi and A, B are matrices with HopRewriteUtils.isEqualSize(A,B).
+	 * @param parent parent hop of hi; its input at position pos is replaced by the new subdag
+	 * @param hi current hop, candidate full sum over a binary operation
+	 * @param pos position of hi within the inputs of parent
+	 * @return the new binary hop sum(A) op sum(B) if the rewrite applied, otherwise hi unchanged
+	 */
+	private Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) 
+	{
+		//all patterns headed by full sum over binary operation
+		if(    hi instanceof AggUnaryOp //full sum root over binaryop
+			&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
+			&& ((AggUnaryOp)hi).getOp() == AggOp.SUM 
+			&& hi.getInput().get(0) instanceof BinaryOp   
+			&& hi.getInput().get(0).getParent().size()==1 ) //single parent, i.e., the intermediate A+B / A-B is not reused elsewhere
+		{
+			BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+			Hop left = bop.getInput().get(0);
+			Hop right = bop.getInput().get(1);
+			
+			if( HopRewriteUtils.isEqualSize(left, right)  //dims(A) == dims(B), excludes matrix-vector/matrix-scalar
+				&& left.getDataType() == DataType.MATRIX
+				&& right.getDataType() == DataType.MATRIX )			
+			{
+				OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
+						|| bop.getOp() == OpOp2.MINUS )     //pattern b: sum(A-B)->sum(A)-sum(B); NOTE(review): may lose precision if A and B nearly cancel
+						? bop.getOp() : null;
+				
+				if( applyOp != null ) {
+					//create new subdag sum(A) bop sum(B)
+					AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
+					AggUnaryOp sum2 = HopRewriteUtils.createSum(right);					
+					BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
+
+					//rewire new subdag
+					HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos); //detach old sum from parent
+					HopRewriteUtils.removeAllChildReferences(hi);  //disconnect old sum from bop
+					HopRewriteUtils.removeAllChildReferences(bop); //disconnect old bop from A and B
+					HopRewriteUtils.addChildReference(parent, newBin, pos); //plug new subdag in at the same position
+					
+					hi = newBin;
+					
+					LOG.debug("Applied pushdownSumOnAdditiveBinary.");
+				}				
+			}
+		}
+	
+		return hi;
+	}
 
 	/**
 	 * Searches for weighted squared loss expressions and replaces them with a quaternary operator. 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
new file mode 100644
index 0000000..1b87231
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/PushdownSumBinaryTest.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+
+/**
+ * Tests the sum pushdown rewrite: sum(A+B) -> sum(A)+sum(B) and sum(A-B) -> sum(A)-sum(B).
+ */
+public class PushdownSumBinaryTest extends AutomatedTestBase 
+{
+	private final static String TEST_NAME1 = "PushdownSum1"; //+
+	private final static String TEST_NAME2 = "PushdownSum2"; //-
+	
+	private final static String TEST_DIR = "functions/aggregate/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + PushdownSumBinaryTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-10;
+	
+	private final static int rows = 1765;
+	private final static int cols = 19;
+	private final static double sparsity = 0.1;
+	
+	
+	@Override
+	public void setUp() 
+	{
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"C"})); 
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"C"})); 
+		TestUtils.clearAssertionInformation();
+
+		if (TEST_CACHE_ENABLED) {
+			setOutAndExpectedDeletionDisabled(true);
+		}
+	}
+
+	@BeforeClass
+	public static void init()
+	{
+		TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+	}
+
+	@AfterClass
+	public static void cleanUp()
+	{
+		if (TEST_CACHE_ENABLED) {
+			TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+		}
+	}
+
+	@Test
+	public void testPushDownSumPlusSP() {
+		runPushdownSumOnBinaryTest(TEST_NAME1, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testPushDownSumMinusSP() {
+		runPushdownSumOnBinaryTest(TEST_NAME2, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testPushDownSumPlusNoRewriteSP() {
+		runPushdownSumOnBinaryTest(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testPushDownSumMinusNoRewriteSP() {
+		runPushdownSumOnBinaryTest(TEST_NAME2, false, ExecType.SPARK);
+	}
+		
+	/**
+	 * Runs the given DML script and its R counterpart on random inputs A and B, compares both results,
+	 * and checks the heavy-hitter opcodes for the expected scalar or Spark map binary operation.
+	 * @param testname test script name: TEST_NAME1 (sum(A+B)) or TEST_NAME2 (sum(A-B))
+	 * @param equiDims if true, B has the same dims as A (rewrite applies); otherwise B is a column vector
+	 * @param instType execution type that determines the runtime platform
+	 */
+	private void runPushdownSumOnBinaryTest( String testname, boolean equiDims, ExecType instType) 
+	{
+		//set runtime platform according to the requested exec type
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( instType ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+		}
+	
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			
+		try
+		{
+			//determine script and function name
+			String TEST_NAME = testname;			
+			String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + "_" + String.valueOf(equiDims) + "/" : "";
+			
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config, TEST_CACHE_DIR);
+			
+			// This is for running the junit test the new way, i.e., construct the arguments directly
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-explain","-stats","-args", input("A"), input("B"), output("C") };
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+	
+			//generate actual dataset (B is a column vector if !equiDims)
+			double[][] A = getRandomMatrix(rows, cols, -1, 1, sparsity, 7); 
+			writeInputMatrixWithMTD("A", A, true);
+			double[][] B = getRandomMatrix(rows, equiDims ? cols : 1, -1, 1, sparsity, 73); 
+			writeInputMatrixWithMTD("B", B, true);
+			
+			//run tests
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare output matrices
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("C");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			String lopcode = TEST_NAME.equals(TEST_NAME1) ? "+" : "-";
+			String opcode = equiDims ? lopcode : Instruction.SP_INST_PREFIX+"map"+lopcode; //rewrite: scalar op; no rewrite: matrix-vector map op
+			Assert.assertTrue("Expected opcode not executed: "+opcode, Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+		}
+		finally
+		{
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum1.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum1.R b/src/test/scripts/functions/aggregate/PushdownSum1.R
new file mode 100644
index 0000000..4eb5c8b
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum1.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+library("Matrix")
+options(digits=22)
+cmdArgs <- commandArgs(TRUE)
+inDir <- cmdArgs[1]; outDir <- cmdArgs[2]
+readInput <- function(name) as.matrix(readMM(paste0(inDir, name, ".mtx")))
+
+A <- readInput("A")
+B <- readInput("B")
+# replicate a column vector B across all columns of A so the result matches the DML matrix-vector semantics
+if (ncol(B) == 1) B <- B %*% matrix(1, 1, ncol(A))
+C <- as.matrix(sum(A + B))
+
+writeMM(as(C, "CsparseMatrix"), paste0(outDir, "C"))

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum1.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum1.dml b/src/test/scripts/functions/aggregate/PushdownSum1.dml
new file mode 100644
index 0000000..e49db15
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum1.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1); # input matrix A
+B = read($2); # input matrix B: same dims as A, or a column vector (depends on the test case)
+C = as.matrix(sum(A+B)) # full sum over A+B; the pattern targeted by the sum pushdown rewrite if dims(A)==dims(B)
+write(C, $3);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum2.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum2.R b/src/test/scripts/functions/aggregate/PushdownSum2.R
new file mode 100644
index 0000000..08986ff
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum2.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+library("Matrix")
+options(digits=22)
+cmdArgs <- commandArgs(TRUE)
+inDir <- cmdArgs[1]; outDir <- cmdArgs[2]
+readInput <- function(name) as.matrix(readMM(paste0(inDir, name, ".mtx")))
+
+A <- readInput("A")
+B <- readInput("B")
+# replicate a column vector B across all columns of A so the result matches the DML matrix-vector semantics
+if (ncol(B) == 1) B <- B %*% matrix(1, 1, ncol(A))
+C <- as.matrix(sum(A - B))
+
+writeMM(as(C, "CsparseMatrix"), paste0(outDir, "C"))

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/19af3f9b/src/test/scripts/functions/aggregate/PushdownSum2.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/PushdownSum2.dml b/src/test/scripts/functions/aggregate/PushdownSum2.dml
new file mode 100644
index 0000000..eec34e7
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/PushdownSum2.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1); # input matrix A
+B = read($2); # input matrix B: same dims as A, or a column vector (depends on the test case)
+C = as.matrix(sum(A-B)) # full sum over A-B; the pattern targeted by the sum pushdown rewrite if dims(A)==dims(B)
+write(C, $3);