You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:46 UTC

[20/23] systemml git commit: Minor optimization inside rewrite rule.

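Minor optimization inside the RewriteElementwiseMultChainOptimization rewrite rule: inline the trivial isOptimizable check (emults.size() >= 2), and construct each leaf's power hop once while sorting the leaves, storing the results in a SortedSet instead of a SortedMap.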


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e

Branch: refs/heads/master
Commit: e93c487ef1778934c94fac291c6e76651041c961
Parents: d18a4c8
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 23:04:21 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 23:04:21 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------
 1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1f85bbf..41fc61d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
 
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
@@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		// 1. Find immediate subtree of EMults. Check dimsKnown.
 		if (isBinaryMult(root) && root.dimsKnown()) {
 			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+			// The set of BinaryOp element-wise multiply hops in the emult chain.
 			final Set<BinaryOp> emults = new HashSet<>();
+			// The multiset of multiplicands in the emult chain.
 			final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
 			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
 			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(emults, leaves)) {
+			// Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+			if (emults.size() >= 2) {
 				// 3. Check for foreign parents.
 				// A foreign parent is a parent of some EMult that is not in the set.
 				// Foreign parents destroy correctness of this rewrite.
@@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 
 	private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
 		// Sort by data type
-		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
 		for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
 			final Hop h = entry.getKey();
 			// unlink parents that are in the emult set(we are throwing them away)
 			// keep other parents
 			h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
-			sorted.put(h, entry.getValue());
+			sorted.add(constructPower(h, entry.getValue()));
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
 		// Construct right-deep EMult tree
-		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		final Iterator<Hop> iterator = sorted.iterator();
 
-		Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		Hop next = iterator.hasNext() ? iterator.next() : null;
 		Hop colVectorsScalars = null;
 		while(next != null &&
 				(next.getDataType() == Expression.DataType.SCALAR
@@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
 				colVectorsScalars.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past col vectors
 
@@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
 				rowVectors.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past row vectors
 
@@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
 				matrices.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past matrices
 
@@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
 				other.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// finished
 
@@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		return top;
 	}
 
-	private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
-		final Hop hop = entry.getKey();
-		final int cnt = entry.getValue();
+	private static Hop constructPower(final Hop hop, final int cnt) {
 		assert(cnt >= 1);
 		hop.setVisited(); // we will visit the leaves' children next
 		if (cnt == 1)
@@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		map.put(k, map.getOrDefault(k, 0) + 1);
 	}
 
-	/**
-	 * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param leaves The multiset of multiplicands in the emult chain.
-	 * @return If the multiset is worth optimizing.
-	 */
-	private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
-		return emults.size() >= 2;
-	}
 }