You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text rendering).
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:46 UTC
[20/23] systemml git commit: Minor optimization inside rewrite rule.
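Minor optimization inside rewrite rule: in RewriteElementwiseMultChainOptimization, construct each leaf's power hop before sorting (a SortedSet<Hop> replaces the SortedMap<Hop,Integer>), and inline the trivial isOptimizable check as emults.size() >= 2.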
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e
Branch: refs/heads/master
Commit: e93c487ef1778934c94fac291c6e76651041c961
Parents: d18a4c8
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 23:04:21 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 23:04:21 2017 -0700
----------------------------------------------------------------------
...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------
1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1f85bbf..41fc61d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
@@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
// 1. Find immediate subtree of EMults. Check dimsKnown.
if (isBinaryMult(root) && root.dimsKnown()) {
final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+ // The set of BinaryOp element-wise multiply hops in the emult chain.
final Set<BinaryOp> emults = new HashSet<>();
+ // The multiset of multiplicands in the emult chain.
final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(emults, leaves)) {
+ // Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+ if (emults.size() >= 2) {
// 3. Check for foreign parents.
// A foreign parent is a parent of some EMult that is not in the set.
// Foreign parents destroy correctness of this rewrite.
@@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Sort by data type
- final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+ final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
final Hop h = entry.getKey();
// unlink parents that are in the emult set(we are throwing them away)
// keep other parents
h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
- sorted.put(h, entry.getValue());
+ sorted.add(constructPower(h, entry.getValue()));
}
// sorted contains all leaves, sorted by data type, stripped from their parents
// Construct right-deep EMult tree
- final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ final Iterator<Hop> iterator = sorted.iterator();
- Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ Hop next = iterator.hasNext() ? iterator.next() : null;
Hop colVectorsScalars = null;
while(next != null &&
(next.getDataType() == Expression.DataType.SCALAR
@@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
colVectorsScalars.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past col vectors
@@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
rowVectors.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past row vectors
@@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
matrices.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past matrices
@@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
other.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// finished
@@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return top;
}
- private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
- final Hop hop = entry.getKey();
- final int cnt = entry.getValue();
+ private static Hop constructPower(final Hop hop, final int cnt) {
assert(cnt >= 1);
hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
@@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
map.put(k, map.getOrDefault(k, 0) + 1);
}
- /**
- * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param leaves The multiset of multiplicands in the emult chain.
- * @return If the multiset is worth optimizing.
- */
- private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
- return emults.size() >= 2;
- }
}