You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:30 UTC

[04/23] systemml git commit: Correct ordering of e-mult chain rewrites.

Correct ordering of e-mult chain rewrites.

Sort scalars, vectors, and matrices into the appropriate order, and order matrices by sparsity (when nnz information is available).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62

Branch: refs/heads/master
Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753
Parents: ff8c836
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 20:48:57 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:21 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++--
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 2 files changed, 74 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 66da6fa..d483a08 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Multiset;
@@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule {
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
-		// Construct left-deep EMult tree
-		Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		// Construct right-deep EMult tree
+		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
 		Hop first = constructPower(iterator.next());
 
 		for (int i = 1; i < sorted.size(); i++) {
 			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
 		}
 		return first;
 	}
@@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule {
 		final Hop hop = entry.getKey();
 		final int cnt = entry.getValue();
 		assert(cnt >= 1);
-		if (cnt == 1)
-			return hop;
+		if (cnt == 1) return hop;
 		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 	}
 
-	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
-			.thenComparing(Hop::getName)
-			.thenComparingInt(Object::hashCode);
+
+
+	// Order: scalars > row vectors > col vectors >
+	//        non-vector matrices ordered by sparsity (higher nnz greater, unknown nnz least) >
+	//        other data types (frames > objects > unknown)
+	// Ties are broken by Hop ID so the order is total and antisymmetric.
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(Hop o1, Hop o2) {
+			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+			case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
+			case MATRIX:
+				// two matrices; check for vectors
+				if (o1.getDim1() == 1) { // row vector
+					if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+					return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+					return -1; // row vectors are the greatest matrices
+				} else if (o1.getDim2() == 1) { // col vector
+					if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+					return compareBySparsityThenId(o1, o2); // both col vectors
+				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is a non-vector
+					return -1; // non-vectors less than col vectors (mirror of branch above)
+				} else { // both non-vectors
+					return compareBySparsityThenId(o1, o2);
+				}
+			default:
+				return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+		private int compareBySparsityThenId(Hop o1, Hop o2) {
+			// more nnz compares greater; unknown nnz (-1) compares least
+			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+		private final int[] orderDataType;
+		{
+			Expression.DataType[] dtValues = Expression.DataType.values();
+			orderDataType = new int[dtValues.length];
+			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+				switch(dtValues[i]) {
+				case SCALAR:
+					orderDataType[i] = 4;
+					break;
+				case MATRIX:
+					orderDataType[i] = 3;
+					break;
+				case FRAME:
+					orderDataType[i] = 2;
+					break;
+				case OBJECT:
+					orderDataType[i] = 1;
+					break;
+				case UNKNOWN:
+					orderDataType[i] = 0;
+					break;
+				}
+			}
+		}
+	};
 
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();

http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index e076c95..18ed55d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			fullRScriptName = HOME + testname + ".R";
 			rCmd = getRCmd(inputDir(), expectedDir());			
 
-			double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
-			double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
 			writeInputMatrixWithMTD("X", X, true);
 			writeInputMatrixWithMTD("Y", Y, true);
 			
@@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
-	}	
+	}
 }