You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:34 UTC
[08/23] systemml git commit: Document RewriteEMult. Add smart recursion.
Document RewriteEMult. Add smart recursion.
RewriteEMult now rewrites emult chains deeper than the top-most one.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f
Branch: refs/heads/master
Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e
Parents: a5846bb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sat Jun 10 01:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:33 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 2 +
.../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++-------
.../org/apache/sysml/parser/Expression.java | 1 -
3 files changed, 62 insertions(+), 31 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 4d23cb9..17ac4ec 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,6 +246,8 @@ public class HopRewriteUtils
* Replace an old Hop with a replacement Hop.
* If the old Hop has no parents, then return the replacement.
* Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
+ * @param old To be replaced
+ * @param replacement The replacement
* @return replacement
*/
public static Hop replaceHop(final Hop old, final Hop replacement) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index d483a08..5cd1471 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Multiset;
*
* Rewrite a chain of element-wise multiply hops that contain identical elements.
* For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
@@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule {
return root;
root.setVisited();
- final ArrayList<Hop> rootInputs = root.getInput();
-
// 1. Find immediate subtree of EMults.
if (isBinaryMult(root)) {
- final Hop left = rootInputs.get(0), right = rootInputs.get(1);
- final BinaryOp r = (BinaryOp)root;
+ final Hop left = root.getInput().get(0), right = root.getInput().get(1);
final Set<BinaryOp> emults = new HashSet<>();
final Multiset<Hop> leaves = HashMultiset.create();
- findEMultsAndLeaves(r, emults, leaves);
+ findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(leaves)) {
+ if (isOptimizable(emults, leaves)) {
// 3. Check for foreign parents.
// A foreign parent is a parent of some EMult that is not in the set.
// Foreign parents destroy correctness of this rewrite.
@@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule {
if (okay) {
// 4. Construct replacement EMults for the leaves
final Hop replacement = constructReplacement(leaves);
- // 5. Replace root with replacement
if (LOG.isDebugEnabled())
LOG.debug(String.format(
"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
emults.size(), root.getHopID(), replacement.getHopID()));
- replacement.setVisited();
- return HopRewriteUtils.replaceHop(root, replacement);
+
+ // 5. Replace root with replacement
+ final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+ // 6. Recurse at leaves (no need to repeat the interior emults)
+ for (final Hop leaf : leaves.elementSet()) {
+ recurseInputs(leaf);
+ }
+ return newRoot;
}
}
}
-
// This rewrite is not applicable to the current root.
// Try the root's children.
- for (int i = 0; i < rootInputs.size(); i++) {
- final Hop input = rootInputs.get(i);
+ recurseInputs(root);
+ return root;
+ }
+
+ private static void recurseInputs(final Hop parent) {
+ final ArrayList<Hop> inputs = parent.getInput();
+ for (int i = 0; i < inputs.size(); i++) {
+ final Hop input = inputs.get(i);
final Hop newInput = rule_RewriteEMult(input);
- rootInputs.set(i, newInput);
+ inputs.set(i, newInput);
}
- return root;
}
private static Hop constructReplacement(final Multiset<Hop> leaves) {
@@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule {
for (int i = 1; i < sorted.size(); i++) {
final Hop second = constructPower(iterator.next());
first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+ first.setVisited();
}
return first;
}
@@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule {
final Hop hop = entry.getKey();
final int cnt = entry.getValue();
assert(cnt >= 1);
- if (cnt == 1) return hop;
- return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ if (cnt == 1)
+ return hop; // don't set this visited... we will visit this next
+ Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ pow.setVisited();
+ return pow;
}
-
-
- // Order: scalars > row vectors > col vectors >
- // non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
- // other data types
- // disambiguate by Hop ID
+ /**
+ * A Comparator that orders Hops by their data type, dimension, and sparsity.
+ * The order is as follows:
+ * scalars > row vectors > col vectors >
+ * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ * other data types.
+ * Disambiguate by Hop ID.
+ */
private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
@Override
public final int compare(Hop o1, Hop o2) {
@@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule {
}
};
+ /**
+ * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param child An interior emult hop in the emult chain dag.
+ * @return Whether this interior emult or any child emult has a foreign parent.
+ */
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
if (parents.size() > 1)
@@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule {
}
/**
- * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root.
+ * Recursively collect all BinaryOp MULTs in the immediate subtree, starting with root, and count their multiplicands.
+ * @param root Root of sub-dag
+ * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+ * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
*/
private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
@@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule {
else leaves.add(right);
}
- /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
- private static boolean isOptimizable(final Multiset<Hop> set) {
- for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
- if (hopEntry.getCount() > 1)
- return true;
- }
- return false;
+ /**
+ * Only optimize a subtree of emults if there are at least two emults.
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param leaves The multiset of multiplicands in the emult chain.
+ * @return If the multiset is worth optimizing.
+ */
+ private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ // Old criterion: there should be at least one repeated leaf
+// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
+// if (hopEntry.getCount() > 1)
+// return true;
+// }
+// return false;
+ return emults.size() >= 2;
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index b944e29..9ee3fba 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,7 +162,6 @@ public abstract class Expression
* Data types (matrix, scalar, frame, object, unknown).
*/
public enum DataType {
- // Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType
MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
public boolean isMatrix() {