You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:42 UTC

[16/23] systemml git commit: Group together types of e-wise multiply inputs. Comprehensive test.

Group together types of e-wise multiply inputs. Comprehensive test.

The test multiplies together all the different kinds of operands (scalars, matrices, row vectors, and column vectors).
The new order of element-wise multiply chains is as follows:

<pre>
    (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
     * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
    * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
</pre>


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/999fdfbc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/999fdfbc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/999fdfbc

Branch: refs/heads/master
Commit: 999fdfbca9ebd855e031e4b812b64f1b484a33d8
Parents: 6c3e1c5
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:10:17 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:10:17 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 122 ++++++++++++++---
 ...ElementwiseMultChainOptimizationAllTest.java | 134 +++++++++++++++++++
 .../functions/misc/RewriteEMultChainOpAll.R     |  37 +++++
 .../functions/misc/RewriteEMultChainOpAll.dml   |  31 +++++
 4 files changed, 305 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9cc8fcd..486072b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -46,6 +46,14 @@ import com.google.common.collect.Multiset;
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
+ *
+ * The new order of element-wise multiply chains is as follows:
+ * <pre>
+ *     (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
+ *      * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
+ *     * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
+ * </pre>
+ * Identical elements are replaced with powers.
  */
 public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	@Override
@@ -137,14 +145,90 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 
 		// Construct right-deep EMult tree
 		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
-		Hop first = constructPower(iterator.next());
 
-		for (int i = 1; i < sorted.size(); i++) {
-			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
-			first.setVisited();
+		Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		Hop colVectorsScalars = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.SCALAR
+						|| next.getDataType() == Expression.DataType.MATRIX && next.getDim2() == 1))
+		{
+			if( colVectorsScalars == null )
+				colVectorsScalars = next;
+			else {
+				colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+				colVectorsScalars.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// next is not processed and is either null or past col vectors
+
+		Hop rowVectors = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.MATRIX && next.getDim1() == 1))
+		{
+			if( rowVectors == null )
+				rowVectors = next;
+			else {
+				rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+				rowVectors.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// next is not processed and is either null or past row vectors
+
+		Hop matrices = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.MATRIX))
+		{
+			if( matrices == null )
+				matrices = next;
+			else {
+				matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+				matrices.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
 		}
-		return first;
+		// next is not processed and is either null or past matrices
+
+		Hop other = null;
+		while(next != null)
+		{
+			if( other == null )
+				other = next;
+			else {
+				other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
+				other.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// finished
+
+		// ((other * matrices) * rowVectors) * colVectorsScalars
+		Hop top = null;
+		if( other == null && matrices != null )
+			top = matrices;
+		else if( other != null && matrices == null )
+			top = other;
+		else if( other != null ) { //matrices != null
+			top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		if( top == null && rowVectors != null )
+			top = rowVectors;
+		else if( rowVectors != null ) { //top != null
+			top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		if( top == null && colVectorsScalars != null )
+			top = colVectorsScalars;
+		else if( colVectorsScalars != null ) { //top != null
+			top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		return top;
 	}
 
 	private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
@@ -154,7 +238,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		hop.setVisited(); // we will visit the leaves' children next
 		if (cnt == 1)
 			return hop;
-		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 		pow.setVisited();
 		return pow;
 	}
@@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	/**
 	 * A Comparator that orders Hops by their data type, dimention, and sparsity.
 	 * The order is as follows:
-	 * 		scalars > col vectors > row vectors >
-	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 * 		scalars < col vectors < row vectors <
+	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) <
 	 *      other data types.
 	 * Disambiguate by Hop ID.
 	 */
@@ -172,11 +256,11 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		{
 			for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
 				switch(Expression.DataType.values()[i]) {
-				case SCALAR: orderDataType[i] = 4; break;
-				case MATRIX: orderDataType[i] = 3; break;
+				case SCALAR: orderDataType[i] = 0; break;
+				case MATRIX: orderDataType[i] = 1; break;
 				case FRAME:  orderDataType[i] = 2; break;
-				case OBJECT: orderDataType[i] = 1; break;
-				case UNKNOWN:orderDataType[i] = 0; break;
+				case OBJECT: orderDataType[i] = 3; break;
+				case UNKNOWN:orderDataType[i] = 4; break;
 				}
 		}
 
@@ -190,15 +274,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 			case MATRIX:
 				// two matrices; check for vectors
 				if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices
+						if (o2.getDim2() != 1) return -1; // col vectors are greatest of matrices
 						return compareBySparsityThenId(o1, o2); // both col vectors
 				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
-						return -1; // col vectors are the greatest matrices
+						return 1; // col vectors are the greatest matrices
 				} else if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors
+						if (o2.getDim1() != 1) return -1; // row vectors greater than non-vectors
 						return compareBySparsityThenId(o1, o2); // both row vectors
 				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return 1; // col vectors greater than non-vectors
+						return 1; // row vectors greater than non-vectors
 				} else { // both non-vectors
 						return compareBySparsityThenId(o1, o2);
 				}
@@ -209,10 +293,10 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		private int compareBySparsityThenId(final Hop o1, final Hop o2) {
 			// the hop with more nnz is first; unknown nnz (-1) last
 			final int c = Long.compare(o1.getNnz(), o2.getNnz());
-			if (c != 0) return c;
+			if (c != 0) return -c;
 			return Long.compare(o1.getHopID(), o2.getHopID());
 		}
-	}.reversed();
+	};
 
 	/**
 	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
new file mode 100644
index 0000000..ba5c78d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test rewriting `2*X*3*v*5*w*4*z*5*Y*2*v*2*X`, where `v` and `z` are row vectors and `w` is a column vector,
+ * successfully rewrites to an equivalent chain with powers, e.g. `Y*(X^2)*z*(v^2)*w*2400` (factor order may differ).
+ */
+public class RewriteElementwiseMultChainOptimizationAllTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpAll";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+	/** Runs {@code testname} on backend {@code et} with sum-product rewrites set to {@code rewrites}; compares DML vs. R. */
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		// use the local Spark configuration when the SPARK backend is forced
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		// toggle sum-product rewrites globally for this run; the old value is restored in finally
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());
+			// random inputs: X, Y are rows x cols; v, z are 1 x cols row vectors; w is a rows x 1 column vector
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+			double[][] z = getRandomMatrix(1, cols, -1, 1, Ysparsity, 5);
+			double[][] v = getRandomMatrix(1, cols, -1, 1, Xsparsity, 4);
+			double[][] w = getRandomMatrix(rows, 1, -1, 1, Ysparsity, 6);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+			writeInputMatrixWithMTD("z", z, true);
+			writeInputMatrixWithMTD("v", v, true);
+			writeInputMatrixWithMTD("w", w, true);
+
+			//execute the DML script, then the R reference script
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare DML and R results cell-wise within eps
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//with rewrites, the repeated X and v should show up as a power operator (^2) in the heavy hitters
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally { // restore the global flags modified above
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
new file mode 100644
index 0000000..20f76c2
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# R reference for RewriteEMultChainOpAll.dml: computes the expected result R from inputs X, Y, v, z, w.
+args <- commandArgs(TRUE)
+# args[1]: input directory prefix; args[2]: output directory prefix (file names are appended via paste)
+options(digits=22)
+library("Matrix")
+library("matrixStats") # NOTE(review): no matrixStats function appears to be used below
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+v = as.matrix(readMM(paste(args[1], "v.mtx", sep="")))
+z = as.matrix(readMM(paste(args[1], "z.mtx", sep="")))
+w = as.matrix(readMM(paste(args[1], "w.mtx", sep="")))
+# row vectors v, z and column vector w are expanded to full length(w) x length(v) matrices via %*% with ones
+R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v)))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
new file mode 100644
index 0000000..90f9242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Inputs: X, Y (matrices), v, z (row vectors), w (column vector); output R is their element-wise product chain
+X = read($1);
+Y = read($2);
+v = read($3);
+z = read($4);
+w = read($5);
+# scalars, matrices and vectors interleaved; X and v each occur twice (candidates for X^2 and v^2)
+R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X
+
+write(R, $6);
\ No newline at end of file