You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:34 UTC

[08/23] systemml git commit: Document RewriteEMult. Add smart recursion.

Document RewriteEMult. Add smart recursion.

RewriteEMult now rewrites emult chains deeper than the top-most one.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f

Branch: refs/heads/master
Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e
Parents: a5846bb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sat Jun 10 01:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:33 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  2 +
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++-------
 .../org/apache/sysml/parser/Expression.java     |  1 -
 3 files changed, 62 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 4d23cb9..17ac4ec 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,6 +246,8 @@ public class HopRewriteUtils
 	 * Replace an old Hop with a replacement Hop.
 	 * If the old Hop has no parents, then return the replacement.
 	 * Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
+	 * @param old The Hop to be replaced
+	 * @param replacement The Hop that takes the place of {@code old}
 	 * @return replacement
 	 */
 	public static Hop replaceHop(final Hop old, final Hop replacement) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index d483a08..5cd1471 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Multiset;
  *
  * Rewrite a chain of element-wise multiply hops that contain identical elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule {
 			return root;
 		root.setVisited();
 
-		final ArrayList<Hop> rootInputs = root.getInput();
-
 		// 1. Find immediate subtree of EMults.
 		if (isBinaryMult(root)) {
-			final Hop left = rootInputs.get(0), right = rootInputs.get(1);
-			final BinaryOp r = (BinaryOp)root;
+			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
 			final Set<BinaryOp> emults = new HashSet<>();
 			final Multiset<Hop> leaves = HashMultiset.create();
-			findEMultsAndLeaves(r, emults, leaves);
+			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
 			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(leaves)) {
+			if (isOptimizable(emults, leaves)) {
 				// 3. Check for foreign parents.
 				// A foreign parent is a parent of some EMult that is not in the set.
 				// Foreign parents destroy correctness of this rewrite.
@@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule {
 				if (okay) {
 					// 4. Construct replacement EMults for the leaves
 					final Hop replacement = constructReplacement(leaves);
-					// 5. Replace root with replacement
 					if (LOG.isDebugEnabled())
 						LOG.debug(String.format(
 								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
 								emults.size(), root.getHopID(), replacement.getHopID()));
-					replacement.setVisited();
-					return HopRewriteUtils.replaceHop(root, replacement);
+
+					// 5. Replace root with replacement
+					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+					// 6. Recurse at leaves (no need to repeat the interior emults)
+					for (final Hop leaf : leaves.elementSet()) {
+						recurseInputs(leaf);
+					}
+					return newRoot;
 				}
 			}
 		}
-
 		// This rewrite is not applicable to the current root.
 		// Try the root's children.
-		for (int i = 0; i < rootInputs.size(); i++) {
-			final Hop input = rootInputs.get(i);
+		recurseInputs(root);
+		return root;
+	}
+
+	private static void recurseInputs(final Hop parent) {
+		final ArrayList<Hop> inputs = parent.getInput();
+		for (int i = 0; i < inputs.size(); i++) {
+			final Hop input = inputs.get(i);
 			final Hop newInput = rule_RewriteEMult(input);
-			rootInputs.set(i, newInput);
+			inputs.set(i, newInput);
 		}
-		return root;
 	}
 
 	private static Hop constructReplacement(final Multiset<Hop> leaves) {
@@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule {
 		for (int i = 1; i < sorted.size(); i++) {
 			final Hop second = constructPower(iterator.next());
 			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+			first.setVisited();
 		}
 		return first;
 	}
@@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule {
 		final Hop hop = entry.getKey();
 		final int cnt = entry.getValue();
 		assert(cnt >= 1);
-		if (cnt == 1) return hop;
-		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		if (cnt == 1)
+			return hop; // don't set this visited... we will visit this next
+		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		pow.setVisited();
+		return pow;
 	}
 
-
-
-	// Order: scalars > row vectors > col vectors >
-	//        non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
-	//        other data types
-	// disambiguate by Hop ID
+	/**
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
+	 * The order is as follows:
+	 * 		scalars > row vectors > col vectors >
+	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 *      other data types.
+	 * Disambiguate by Hop ID.
+	 */
 	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
 		@Override
 		public final int compare(Hop o1, Hop o2) {
@@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule {
 		}
 	};
 
+	/**
+	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param child An interior emult hop in the emult chain dag.
+	 * @return Whether this interior emult or any child emult has a foreign parent.
+	 */
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();
 		if (parents.size() > 1)
@@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule {
 	}
 
 	/**
-	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root.
+	 * Recursively collect all BinaryOp MULTs in the immediate subtree starting with root, along with the multiset of their non-MULT operands.
+	 * @param root Root of sub-dag
+	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
 	 */
 	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
 		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
@@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule {
 		else leaves.add(right);
 	}
 
-	/** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
-	private static boolean isOptimizable(final Multiset<Hop> set) {
-		for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
-			if (hopEntry.getCount() > 1)
-				return true;
-		}
-		return false;
+	/**
+	 * Only optimize a subtree of emults if there are at least two emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The multiset of multiplicands in the emult chain.
+	 * @return Whether the emult chain is worth optimizing.
+	 */
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Old criterion: there should be at least one repeated leaf
+//		for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
+//			if (hopEntry.getCount() > 1)
+//				return true;
+//		}
+//		return false;
+		return emults.size() >= 2;
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index b944e29..9ee3fba 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,7 +162,6 @@ public abstract class Expression
 	 * Data types (matrix, scalar, frame, object, unknown).
 	 */
 	public enum DataType {
-		// Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType
 		MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
 		
 		public boolean isMatrix() {