You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:27 UTC
[01/23] systemml git commit: New rewrite rule for chains of
element-wise multiply.
Repository: systemml
Updated Branches:
refs/heads/master 1b3dff06b -> 85e3a9631
New rewrite rule for chains of element-wise multiply.
Placed rewrite rule after Common Subexpression Elimination.
Included helper method in HopRewriteUtils.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d578838
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d578838
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d578838
Branch: refs/heads/master
Commit: 7d578838cc291a1adb6229bae01f7c9428b6f858
Parents: c434208
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jun 8 18:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:13 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 17 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 1 +
.../apache/sysml/hops/rewrite/RewriteEMult.java | 186 +++++++++++++++++++
.../org/apache/sysml/parser/Expression.java | 1 +
4 files changed, 204 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index cf6081b..4d23cb9 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -241,7 +241,22 @@ public class HopRewriteUtils
parent.getInput().add( pos, child );
child.getParent().add( parent );
}
-
+
+ /**
+ * Replace {@code old} by {@code replacement}: if {@code old} has parents, each of them is rewired to reference {@code replacement} instead; if it has none (e.g. it is a DAG root), nothing is rewired and the caller must install the returned hop itself.
+ * @param old the hop to be replaced
+ * @param replacement the hop that takes the place of {@code old}
+ * @return {@code replacement}, in both cases
+ */
+ public static Hop replaceHop(final Hop old, final Hop replacement) {
+ final ArrayList<Hop> rootParents = old.getParent();
+ if (rootParents.isEmpty())
+ return replacement; // no parents to rewire; caller installs replacement (e.g. as a new root)
+ HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
+ return replacement;
+ }
+
+
public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) {
ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent());
for( Hop lparent : parents )
http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 0e65f3f..8573dd7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,6 +96,7 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
+ _dagRuleSet.add( new RewriteEMult() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
new file mode 100644
index 0000000..47c32a9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops in which some multiplicand occurs more than once.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteEMult extends HopRewriteRule {
+ @Override
+ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+ if( roots == null )
+ return null;
+ // Each root may be replaced by a different hop, so store the result back into the list.
+ for( int i=0; i<roots.size(); i++ ) {
+ Hop h = roots.get(i);
+ roots.set(i, rule_RewriteEMult(h));
+ }
+ return roots;
+ }
+
+ @Override
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+ if( root == null )
+ return null;
+ return rule_RewriteEMult(root);
+ }
+ /** Returns true iff {@code hop} is a BinaryOp with operator MULT (element-wise multiply). */
+ private static boolean isBinaryMult(final Hop hop) {
+ return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+ }
+ /** Rewrites the e-mult chain rooted at {@code root} if profitable and safe, else recurses into its inputs; returns the hop that should take {@code root}'s place. */
+ private static Hop rule_RewriteEMult(final Hop root) {
+ if (root.isVisited())
+ return root;
+ root.setVisited();
+
+ final ArrayList<Hop> rootInputs = root.getInput();
+
+ // 1. Find immediate subtree of EMults.
+ if (isBinaryMult(root)) {
+ final Hop left = rootInputs.get(0), right = rootInputs.get(1);
+ final BinaryOp r = (BinaryOp)root;
+ final Set<BinaryOp> emults = new HashSet<>();
+ final Multiset<Hop> leaves = HashMultiset.create();
+ findEMultsAndLeaves(r, emults, leaves);
+ // 2. Ensure it is profitable to do a rewrite (some leaf occurs more than once).
+ if (isOptimizable(leaves)) {
+ // 3. Check for foreign parents.
+ // A foreign parent is a parent of some EMult that is not in the set.
+ // Foreign parents destroy correctness of this rewrite.
+ final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+ (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+ if (okay) {
+ // 4. Construct replacement EMults for the leaves
+ final Hop replacement = constructReplacement(leaves);
+
+ // 5. Replace root with replacement. NOTE(review): returning here means the leaves' own subtrees are not visited by this call, so chains nested below a leaf are not rewritten in this pass.
+ return HopRewriteUtils.replaceHop(root, replacement);
+ }
+ }
+ }
+
+ // This rewrite is not applicable to the current root.
+ // Try the root's children.
+ for (int i = 0; i < rootInputs.size(); i++) {
+ final Hop input = rootInputs.get(i);
+ final Hop newInput = rule_RewriteEMult(input);
+ rootInputs.set(i, newInput);
+ }
+ return root;
+ }
+ /** Builds a left-deep product of the distinct leaves (each repeated leaf becomes leaf^count), after detaching the leaves from their current parents. */
+ private static Hop constructReplacement(final Multiset<Hop> leaves) {
+ // Sort by data type. NOTE(review): TreeMap treats keys whose compare() is 0 as the same key, and compareByDataType compares only DataType, so distinct leaves of equal type (e.g. two matrices) overwrite each other below and a factor is silently dropped; a tie-breaker (e.g. Hop ID) is needed.
+ final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+ for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+ final Hop h = entry.getElement();
+ // unlink parents (the EMults, which we are throwing away). NOTE(review): clear() also drops parents outside this chain when the leaf is shared elsewhere in the DAG; only parents in the EMult set should be removed.
+ h.getParent().clear();
+ sorted.put(h, entry.getCount());
+ }
+ // sorted contains all leaves, sorted by data type, stripped from their parents
+
+ // Construct left-deep EMult tree
+ Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ Hop first = constructPower(iterator.next());
+
+ for (int i = 1; i < sorted.size(); i++) {
+ final Hop second = constructPower(iterator.next());
+ first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+ }
+ return first;
+ }
+ /** Returns the leaf itself if its count is 1, otherwise a new POW BinaryOp computing leaf^count. */
+ private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+ final Hop hop = entry.getKey();
+ final int cnt = entry.getValue();
+ assert(cnt >= 1);
+ if (cnt == 1)
+ return hop;
+ return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ }
+ // Orders hops by the declaration order of their DataType enum constant; hops of equal DataType compare as 0 (not a total order on hops).
+ private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType);
+ /** Returns true iff neither {@code child} nor any EMult reachable through its EMult inputs has a parent outside {@code emults}. */
+ private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+ final ArrayList<Hop> parents = child.getParent();
+ if (parents.size() > 1)
+ for (final Hop parent : parents)
+ //noinspection SuspiciousMethodCalls
+ if (!emults.contains(parent))
+ return false;
+ // child does not have foreign parents (a lone parent must be the EMult this method descended from)
+ // Recurse into EMult inputs only; non-MULT inputs are leaves and their parents are not checked.
+ final ArrayList<Hop> inputs = child.getInput();
+ final Hop left = inputs.get(0), right = inputs.get(1);
+ return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+ (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+ }
+
+ /**
+ * Collect every MULT BinaryOp of the chain rooted at {@code root} into {@code emults}, and every non-MULT input of the chain into {@code leaves} (with multiplicity).
+ */
+ private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ // Relies on RewriteCommonSubexpressionElimination having merged identical subexpressions into shared Hop objects, so repeated leaves compare equal and are counted together in the Multiset.
+ emults.add(root);
+
+ final ArrayList<Hop> inputs = root.getInput();
+ final Hop left = inputs.get(0), right = inputs.get(1);
+
+ if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+ else leaves.add(left);
+
+ if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+ else leaves.add(right);
+ }
+
+ /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
+ private static boolean isOptimizable(final Multiset<Hop> set) {
+ for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
+ if (hopEntry.getCount() > 1)
+ return true;
+ }
+ return false;
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 9ee3fba..b944e29 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,6 +162,7 @@ public abstract class Expression
* Data types (matrix, scalar, frame, object, unknown).
*/
public enum DataType {
+ // Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType
MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
public boolean isMatrix() {
[12/23] systemml git commit: Fix visit status bug
Posted by mb...@apache.org.
Fix visit status bug
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0a8936cd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0a8936cd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0a8936cd
Branch: refs/heads/master
Commit: 0a8936cd849d74baced732f45f1c53812abce537
Parents: d6d3795
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 03:55:25 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:48 2017 -0700
----------------------------------------------------------------------
.../apache/sysml/hops/rewrite/HopDagValidator.java | 5 ++++-
.../RewriteElementwiseMultChainOptimization.java | 17 +++++++++--------
2 files changed, 13 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
index 8cb5e1e..9ac21fc 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
@@ -35,6 +35,8 @@ import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.utils.Explain;
+import com.google.common.collect.Lists;
+
import static org.apache.sysml.hops.HopsException.check;
/**
@@ -89,7 +91,8 @@ public class HopDagValidator {
//check visit status
final boolean seen = !state.seen.add(id);
check(seen == hop.isVisited(), hop,
- "seen previously is %b but does not match hop visit status", seen);
+ "(parents: %s) seen previously is %b but does not match hop visit status",
+ Lists.transform(hop.getParent(), Hop::getHopID), seen);
if (seen) return; // we saw the Hop previously, no need to re-validate
//check parent linking
http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 91b7306..9ca0932 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -91,7 +91,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
if (okay) {
// 4. Construct replacement EMults for the leaves
- final Hop replacement = constructReplacement(leaves);
+ final Hop replacement = constructReplacement(emults, leaves);
if (LOG.isDebugEnabled())
LOG.debug(String.format(
"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
@@ -123,13 +123,14 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
}
- private static Hop constructReplacement(final Multiset<Hop> leaves) {
+ private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
// Sort by data type
final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
final Hop h = entry.getElement();
- // unlink parents (the EMults, which we are throwing away)
- h.getParent().clear();
+ // unlink parents that are in the emult set(we are throwing them away)
+ // keep other parents
+ h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
sorted.put(h, entry.getCount());
}
// sorted contains all leaves, sorted by data type, stripped from their parents
@@ -146,12 +147,13 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return first;
}
- private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+ private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
final Hop hop = entry.getKey();
final int cnt = entry.getValue();
assert(cnt >= 1);
+ hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
- return hop; // don't set this visited... we will visit this next
+ return hop;
Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
pow.setVisited();
return pow;
@@ -222,8 +224,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
final ArrayList<Hop> parents = child.getParent();
if (parents.size() > 1)
for (final Hop parent : parents)
- //noinspection SuspiciousMethodCalls (for Intellij, which checks when
- if (!emults.contains(parent))
+ if (parent instanceof BinaryOp && !emults.contains(parent))
return false;
// child does not have foreign parents
[13/23] systemml git commit: Relax tolerance for
ElementwiseAdditionMultiplicationTest
Posted by mb...@apache.org.
Relax tolerance for ElementwiseAdditionMultiplicationTest
The new RewriteElementwiseMultChainOptimization reorders
`(A*B)*C` to `A*(B*C)`. Floating-point multiplication is not associative,
so the result is no longer bit-identical to the reference.
Compare with an epsilon of 1e-10 instead.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/de469d23
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/de469d23
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/de469d23
Branch: refs/heads/master
Commit: de469d235e5fe06fd3e13a32262d6c357ffdcc81
Parents: 0a8936c
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 12:50:39 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:51 2017 -0700
----------------------------------------------------------------------
.../binary/matrix/ElementwiseAdditionMultiplicationTest.java | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/de469d23/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
index 523a648..f78e598 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
@@ -134,6 +134,6 @@ public class ElementwiseAdditionMultiplicationTest extends AutomatedTestBase
runTest();
- compareResults();
+ compareResults(1e-10);
}
}
[20/23] systemml git commit: Minor optimization inside rewrite rule.
Posted by mb...@apache.org.
Minor optimization inside rewrite rule.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e
Branch: refs/heads/master
Commit: e93c487ef1778934c94fac291c6e76651041c961
Parents: d18a4c8
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 23:04:21 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 23:04:21 2017 -0700
----------------------------------------------------------------------
...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------
1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1f85bbf..41fc61d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
@@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
// 1. Find immediate subtree of EMults. Check dimsKnown.
if (isBinaryMult(root) && root.dimsKnown()) {
final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+ // The set of BinaryOp element-wise multiply hops in the emult chain.
final Set<BinaryOp> emults = new HashSet<>();
+ // The multiset of multiplicands in the emult chain.
final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(emults, leaves)) {
+ // Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+ if (emults.size() >= 2) {
// 3. Check for foreign parents.
// A foreign parent is a parent of some EMult that is not in the set.
// Foreign parents destroy correctness of this rewrite.
@@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Sort by data type
- final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+ final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
final Hop h = entry.getKey();
// unlink parents that are in the emult set(we are throwing them away)
// keep other parents
h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
- sorted.put(h, entry.getValue());
+ sorted.add(constructPower(h, entry.getValue()));
}
// sorted contains all leaves, sorted by data type, stripped from their parents
// Construct right-deep EMult tree
- final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ final Iterator<Hop> iterator = sorted.iterator();
- Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ Hop next = iterator.hasNext() ? iterator.next() : null;
Hop colVectorsScalars = null;
while(next != null &&
(next.getDataType() == Expression.DataType.SCALAR
@@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
colVectorsScalars.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past col vectors
@@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
rowVectors.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past row vectors
@@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
matrices.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// next is not processed and is either null or past matrices
@@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
other.setVisited();
}
- next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ next = iterator.hasNext() ? iterator.next() : null;
}
// finished
@@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return top;
}
- private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
- final Hop hop = entry.getKey();
- final int cnt = entry.getValue();
+ private static Hop constructPower(final Hop hop, final int cnt) {
assert(cnt >= 1);
hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
@@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
map.put(k, map.getOrDefault(k, 0) + 1);
}
- /**
- * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param leaves The multiset of multiplicands in the emult chain.
- * @return If the multiset is worth optimizing.
- */
- private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
- return emults.size() >= 2;
- }
}
[16/23] systemml git commit: Group together types of e-wise multiply
inputs. Comprehensive test.
Posted by mb...@apache.org.
Group together types of e-wise multiply inputs. Comprehensive test.
The test covers all the different kinds of objects multiplied together.
The new order of element-wise multiply chains is as follows:
<pre>
(((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
* ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
* ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
</pre>
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/999fdfbc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/999fdfbc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/999fdfbc
Branch: refs/heads/master
Commit: 999fdfbca9ebd855e031e4b812b64f1b484a33d8
Parents: 6c3e1c5
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:10:17 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:10:17 2017 -0700
----------------------------------------------------------------------
...RewriteElementwiseMultChainOptimization.java | 122 ++++++++++++++---
...ElementwiseMultChainOptimizationAllTest.java | 134 +++++++++++++++++++
.../functions/misc/RewriteEMultChainOpAll.R | 37 +++++
.../functions/misc/RewriteEMultChainOpAll.dml | 31 +++++
4 files changed, 305 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9cc8fcd..486072b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -46,6 +46,14 @@ import com.google.common.collect.Multiset;
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
+ *
+ * The new order of element-wise multiply chains is as follows:
+ * <pre>
+ * (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
+ * * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
+ * * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
+ * </pre>
+ * Identical elements are replaced with powers.
*/
public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
@Override
@@ -137,14 +145,90 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
// Construct right-deep EMult tree
final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
- Hop first = constructPower(iterator.next());
- for (int i = 1; i < sorted.size(); i++) {
- final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
- first.setVisited();
+ Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ Hop colVectorsScalars = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.SCALAR
+ || next.getDataType() == Expression.DataType.MATRIX && next.getDim2() == 1))
+ {
+ if( colVectorsScalars == null )
+ colVectorsScalars = next;
+ else {
+ colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+ colVectorsScalars.setVisited();
+ }
+ next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ }
+ // next is not processed and is either null or past col vectors
+
+ Hop rowVectors = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.MATRIX && next.getDim1() == 1))
+ {
+ if( rowVectors == null )
+ rowVectors = next;
+ else {
+ rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+ rowVectors.setVisited();
+ }
+ next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ }
+ // next is not processed and is either null or past row vectors
+
+ Hop matrices = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.MATRIX))
+ {
+ if( matrices == null )
+ matrices = next;
+ else {
+ matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+ matrices.setVisited();
+ }
+ next = iterator.hasNext() ? constructPower(iterator.next()) : null;
}
- return first;
+ // next is not processed and is either null or past matrices
+
+ Hop other = null;
+ while(next != null)
+ {
+ if( other == null )
+ other = next;
+ else {
+ other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
+ other.setVisited();
+ }
+ next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+ }
+ // finished
+
+ // ((other * matrices) * rowVectors) * colVectorsScalars
+ Hop top = null;
+ if( other == null && matrices != null )
+ top = matrices;
+ else if( other != null && matrices == null )
+ top = other;
+ else if( other != null ) { //matrices != null
+ top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
+ top.setVisited();
+ }
+
+ if( top == null && rowVectors != null )
+ top = rowVectors;
+ else if( rowVectors != null ) { //top != null
+ top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
+ top.setVisited();
+ }
+
+ if( top == null && colVectorsScalars != null )
+ top = colVectorsScalars;
+ else if( colVectorsScalars != null ) { //top != null
+ top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
+ top.setVisited();
+ }
+
+ return top;
}
private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
@@ -154,7 +238,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
return hop;
- Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
pow.setVisited();
return pow;
}
@@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
/**
* A Comparator that orders Hops by their data type, dimention, and sparsity.
* The order is as follows:
- * scalars > col vectors > row vectors >
- * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ * scalars < col vectors < row vectors <
+ * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) <
* other data types.
* Disambiguate by Hop ID.
*/
@@ -172,11 +256,11 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
{
for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
switch(Expression.DataType.values()[i]) {
- case SCALAR: orderDataType[i] = 4; break;
- case MATRIX: orderDataType[i] = 3; break;
+ case SCALAR: orderDataType[i] = 0; break;
+ case MATRIX: orderDataType[i] = 1; break;
case FRAME: orderDataType[i] = 2; break;
- case OBJECT: orderDataType[i] = 1; break;
- case UNKNOWN:orderDataType[i] = 0; break;
+ case OBJECT: orderDataType[i] = 3; break;
+ case UNKNOWN:orderDataType[i] = 4; break;
}
}
@@ -190,15 +274,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
case MATRIX:
// two matrices; check for vectors
if (o1.getDim2() == 1) { // col vector
- if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices
+ if (o2.getDim2() != 1) return -1; // col vectors are greatest of matrices
return compareBySparsityThenId(o1, o2); // both col vectors
} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
- return -1; // col vectors are the greatest matrices
+ return 1; // col vectors are the greatest matrices
} else if (o1.getDim1() == 1) { // row vector
- if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors
+ if (o2.getDim1() != 1) return -1; // row vectors greater than non-vectors
return compareBySparsityThenId(o1, o2); // both row vectors
} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
- return 1; // col vectors greater than non-vectors
+ return 1; // row vectors greater than non-vectors
} else { // both non-vectors
return compareBySparsityThenId(o1, o2);
}
@@ -209,10 +293,10 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
private int compareBySparsityThenId(final Hop o1, final Hop o2) {
// the hop with more nnz is first; unknown nnz (-1) last
final int c = Long.compare(o1.getNnz(), o2.getNnz());
- if (c != 0) return c;
+ if (c != 0) return -c;
return Long.compare(o1.getHopID(), o2.getHopID());
}
- }.reversed();
+ };
/**
* Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
new file mode 100644
index 0000000..ba5c78d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests that the element-wise multiply chain `2*X*3*v*5*w*4*z*5*Y*2*v*2*X` (row vectors `v`, `z`; column vector `w`)
+ * rewrites to `Y*(X^2)*z*(v^2)*w*2400`; results are compared against R, and with rewrites on a `^2` heavy hitter is expected.
+ */
+public class RewriteElementwiseMultChainOptimizationAllTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOpAll";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/";
+ // Shape of X and Y; the vector inputs are 1 x cols (v, z) and rows x 1 (w).
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10); // tolerance for the DML-vs-R comparison
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+ // Shared driver: runs the DML script on the requested backend with sum-product rewrites on or off and compares against R.
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+ // Spark runs use a local Spark configuration; the old setting is restored in finally.
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ // Toggle the rewrites under test; the old value is restored in finally.
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+ // X and v use sparsity 0.8; Y, z, and w use sparsity 0.6.
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+ double[][] z = getRandomMatrix(1, cols, -1, 1, Ysparsity, 5);
+ double[][] v = getRandomMatrix(1, cols, -1, 1, Xsparsity, 4);
+ double[][] w = getRandomMatrix(rows, 1, -1, 1, Ysparsity, 6);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+ writeInputMatrixWithMTD("z", z, true);
+ writeInputMatrixWithMTD("v", v, true);
+ writeInputMatrixWithMTD("w", w, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare the DML result against the R reference within eps
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //with rewrites enabled, expect a power (^2) operator among the heavy hitters
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally { // restore the global settings changed above
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
new file mode 100644
index 0000000..20f76c2
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+# args[1]: input directory holding the X/Y/v/z/w .mtx files; args[2]: output directory for the result R
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+# Read the MatrixMarket inputs as dense matrices.
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+v = as.matrix(readMM(paste(args[1], "v.mtx", sep="")))
+z = as.matrix(readMM(paste(args[1], "z.mtx", sep="")))
+w = as.matrix(readMM(paste(args[1], "w.mtx", sep="")))
+# Reference: row vectors z, v are tiled over rows (ones %*% vec) and column vector w over columns (w %*% ones).
+R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v)))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
new file mode 100644
index 0000000..90f9242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+v = read($3);
+z = read($4);
+w = read($5);
+# X, Y: rows x cols matrices; v, z: 1 x cols row vectors; w: rows x 1 column vector (as generated by the Java test driver).
+R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X
+# X and v each appear twice, so this chain is a candidate for the X^2 / v^2 e-mult chain rewrite.
+write(R, $6);
\ No newline at end of file
[15/23] systemml git commit: Change order of row and col vectors,
so as to create inner products rather than outer products.
Posted by mb...@apache.org.
Change order of row and col vectors, so as to create inner products rather than outer products.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c3e1c5b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c3e1c5b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c3e1c5b
Branch: refs/heads/master
Commit: 6c3e1c5bad30dc8f11ff9d3f412ce68873c37202
Parents: 04f692d
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 20:08:22 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 20:08:22 2017 -0700
----------------------------------------------------------------------
...RewriteElementwiseMultChainOptimization.java | 24 ++++++++++----------
1 file changed, 12 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/6c3e1c5b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9ca0932..9cc8fcd 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -162,7 +162,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
/**
* A Comparator that orders Hops by their data type, dimention, and sparsity.
* The order is as follows:
- * scalars > row vectors > col vectors >
+ * scalars > col vectors > row vectors >
* non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
* other data types.
* Disambiguate by Hop ID.
@@ -181,23 +181,23 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
@Override
- public final int compare(Hop o1, Hop o2) {
- int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+ public final int compare(final Hop o1, final Hop o2) {
+ final int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
if (c != 0) return c;
// o1 and o2 have the same data type
switch (o1.getDataType()) {
case MATRIX:
// two matrices; check for vectors
- if (o1.getDim1() == 1) { // row vector
- if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
- return compareBySparsityThenId(o1, o2); // both row vectors
- } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
- return -1; // row vectors are the greatest matrices
- } else if (o1.getDim2() == 1) { // col vector
- if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+ if (o1.getDim2() == 1) { // col vector
+ if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices
return compareBySparsityThenId(o1, o2); // both col vectors
} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+ return -1; // col vectors are the greatest matrices
+ } else if (o1.getDim1() == 1) { // row vector
+ if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors
+ return compareBySparsityThenId(o1, o2); // both row vectors
+ } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
return 1; // col vectors greater than non-vectors
} else { // both non-vectors
return compareBySparsityThenId(o1, o2);
@@ -206,9 +206,9 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return Long.compare(o1.getHopID(), o2.getHopID());
}
}
- private int compareBySparsityThenId(Hop o1, Hop o2) {
+ private int compareBySparsityThenId(final Hop o1, final Hop o2) {
// the hop with more nnz is first; unknown nnz (-1) last
- int c = Long.compare(o1.getNnz(), o2.getNnz());
+ final int c = Long.compare(o1.getNnz(), o2.getNnz());
if (c != 0) return c;
return Long.compare(o1.getHopID(), o2.getHopID());
}
[02/23] systemml git commit: Fix RewriteEMult comparator. Add tests.
Posted by mb...@apache.org.
Fix RewriteEMult comparator. Add tests.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eb0599df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eb0599df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eb0599df
Branch: refs/heads/master
Commit: eb0599df4c3bcca15531b85a3d870a26e4653179
Parents: 7d57883
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 11:18:32 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:15 2017 -0700
----------------------------------------------------------------------
.../org/apache/sysml/hops/OptimizerUtils.java | 9 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 3 +-
.../apache/sysml/hops/rewrite/RewriteEMult.java | 10 +-
.../functions/misc/RewriteEMultChainTest.java | 127 +++++++++
.../ternary/ABATernaryAggregateTest.java | 268 +++++++++++++++++++
.../functions/misc/RewriteEMultChainOp.R | 33 +++
.../functions/misc/RewriteEMultChainOp.dml | 28 ++
.../functions/ternary/ABATernaryAggregateC.R | 32 +++
.../functions/ternary/ABATernaryAggregateC.dml | 30 +++
.../functions/ternary/ABATernaryAggregateRC.R | 33 +++
.../functions/ternary/ABATernaryAggregateRC.dml | 30 +++
11 files changed, 597 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index a40e36c..2a76d07 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -110,8 +110,13 @@ public class OptimizerUtils
*/
public static boolean ALLOW_CONSTANT_FOLDING = true;
- public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
- public static boolean ALLOW_OPERATOR_FUSION = true;
+ public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
+ /**
+ * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
+ * `A*B*A ==> (A^2)*B`.
+ */
+ public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
+ public static boolean ALLOW_OPERATOR_FUSION = true;
/**
* Enables if-else branch removal for constant predicates (original literals or
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8573dd7..b6aab38 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,7 +96,8 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
- _dagRuleSet.add( new RewriteEMult() ); //dependency: cse
+ if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
+ _dagRuleSet.add( new RewriteEMult() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 47c32a9..2c9e5cb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -50,7 +50,6 @@ public class RewriteEMult extends HopRewriteRule {
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
if( roots == null )
return null;
-
for( int i=0; i<roots.size(); i++ ) {
Hop h = roots.get(i);
roots.set(i, rule_RewriteEMult(h));
@@ -83,6 +82,7 @@ public class RewriteEMult extends HopRewriteRule {
final Set<BinaryOp> emults = new HashSet<>();
final Multiset<Hop> leaves = HashMultiset.create();
findEMultsAndLeaves(r, emults, leaves);
+
// 2. Ensure it is profitable to do a rewrite.
if (isOptimizable(leaves)) {
// 3. Check for foreign parents.
@@ -93,8 +93,12 @@ public class RewriteEMult extends HopRewriteRule {
if (okay) {
// 4. Construct replacement EMults for the leaves
final Hop replacement = constructReplacement(leaves);
-
// 5. Replace root with replacement
+ if (LOG.isDebugEnabled())
+ LOG.debug(String.format(
+ "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+ emults.size(), root.getHopID(), replacement.getHopID()));
+ replacement.setVisited();
return HopRewriteUtils.replaceHop(root, replacement);
}
}
@@ -141,7 +145,7 @@ public class RewriteEMult extends HopRewriteRule {
return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
}
- private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType);
+ private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
new file mode 100644
index 0000000..e076c95
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests that the element-wise multiply chain `A*B*A` rewrites to `(A^2)*B`, comparing results against R and, with the rewrite on, expecting a `^2` heavy hitter.
+ */
+public class RewriteEMultChainTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOp";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
+ // Shape of both random inputs X and Y.
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10); // tolerance for the DML-vs-R comparison
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+ // Shared driver: runs the DML script on the requested backend with the e-mult chain rewrite on or off and compares against R.
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+ // Spark runs use a local Spark configuration; the old setting is restored in finally.
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ // Toggle the rewrite under test; the old value is restored in finally.
+ boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+ OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{ "-explain", "hops", "-stats",
+ "-args", input("X"), input("Y"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+ // X and Y are rows x cols with sparsity 0.97 and 0.9 respectively.
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare the DML result against the R reference within eps
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //with the rewrite enabled, expect a power (^2) operator among the heavy hitters
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally { // restore the global settings changed above
+ OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
new file mode 100644
index 0000000..198e9f4
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.ternary;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`.
+ * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ */
+public class ABATernaryAggregateTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "ABATernaryAggregateRC";
+ private final static String TEST_NAME2 = "ABATernaryAggregateC";
+
+ private final static String TEST_DIR = "functions/ternary/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-8;
+
+ private final static int rows = 1111;
+ private final static int cols = 1011;
+
+ private final static double sparsity1 = 0.7;
+ private final static double sparsity2 = 0.3;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorCP() {
+ runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorCP() {
+ runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixCP() {
+ runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixCP() {
+ runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorSP() {
+ runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorSP() {
+ runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixSP() {
+ runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixSP() {
+ runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorMR() {
+ runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorMR() {
+ runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixMR() {
+ runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixMR() {
+ runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorCP() {
+ runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorCP() {
+ runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixCP() {
+ runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixCP() {
+ runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorSP() {
+ runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorSP() {
+ runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixSP() {
+ runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixSP() {
+ runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.SPARK);
+ }
+
+ //additional tests to check default without rewrites
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME1, false, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME1, true, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME1, false, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME1, true, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME2, false, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME2, true, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME2, false, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
+ runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
+ }
+
+
+
+ private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et)
+ {
+ //rtplatform for MR
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
+ rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+
+ try {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+ OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{"-explain","hops","-stats","-args", input("A"), output("R")};
+
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " +
+ inputDir() + " " + expectedDir();
+
+ //generate actual dataset
+ double sparsity = sparse ? sparsity2 : sparsity1;
+ double[][] A = getRandomMatrix(vectors ? rows*cols : rows,
+ vectors ? 1 : cols, 0, 1, sparsity, 17);
+ writeInputMatrixWithMTD("A", A, true);
+
+ //run test cases
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare output matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check for rewritten patterns in statistics output
+ if( rewrites && et != ExecType.MR ) {
+ String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") +
+ (((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*"));
+ Assert.assertTrue(Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+ }
+ }
+ finally {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+# readMM/writeMM come from the Matrix package; no other package is needed here
+library("Matrix")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;  # X occurs twice in this e-wise multiply chain; the chain rewrite presumably turns it into (X^2)*Y - confirm
+
+write(R, $3);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
new file mode 100644
index 0000000..9601089
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes the column sums of A*B*A (with B = 2*A) as a 1 x ncol(A) matrix and writes it as "R".
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A <- as.matrix(readMM(paste0(args[1], "A.mtx")))
+B <- 2 * A
+colTotals <- colSums(A * B * A)
+R <- t(as.matrix(colTotals))
+
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
new file mode 100644
index 0000000..78285af
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = A * 2;
+C = A * 3;  # NOTE(review): C is not used anywhere below - confirm this is intentional
+# no-op if-block; presumably cuts the statement block so B is not inlined into the expression below - TODO confirm
+if(1==1){}
+
+R = colSums(A * B * A);  # with rewrites (non-MR) ABATernaryAggregateTest expects tack+* for matrix A, tak+* for column-vector A
+
+write(R, $2);
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
new file mode 100644
index 0000000..6552c7e
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes the full sum of A*B*A (with B = 2*A) as a 1 x 1 matrix and writes it as "R".
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+inDir <- args[1]
+outDir <- args[2]
+A <- as.matrix(readMM(paste0(inDir, "A.mtx")))
+B <- 2 * A
+R <- as.matrix(sum(A * B * A))
+
+writeMM(as(R, "CsparseMatrix"), paste0(outDir, "R"))
http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
new file mode 100644
index 0000000..965c8d3
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = A * 2;
+# no-op if-block; presumably cuts the statement block so B is not inlined into the expression below - TODO confirm
+if(1==1){}
+
+s = sum(A * B * A);  # A occurs twice; with rewrites (non-MR) ABATernaryAggregateTest expects the fused tak+* opcode
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file
[09/23] systemml git commit: Review comments, part 1
Posted by mb...@apache.org.
Review comments, part 1
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b94557fd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b94557fd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b94557fd
Branch: refs/heads/master
Commit: b94557fd2c90c591179cdbf05a32242fadc36448
Parents: d88f867
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 00:35:52 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:37 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 21 +-
.../org/apache/sysml/hops/OptimizerUtils.java | 5 -
.../sysml/hops/rewrite/HopRewriteUtils.java | 3 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 4 +-
.../RewriteAlgebraicSimplificationDynamic.java | 16 +-
.../apache/sysml/hops/rewrite/RewriteEMult.java | 284 -------------------
...RewriteElementwiseMultChainOptimization.java | 281 ++++++++++++++++++
.../functions/misc/RewriteEMultChainTest.java | 127 ---------
...ementwiseMultChainOptimizationChainTest.java | 127 +++++++++
.../ternary/ABATernaryAggregateTest.java | 9 +-
.../functions/misc/ZPackageSuite.java | 1 +
.../functions/ternary/ZPackageSuite.java | 3 +-
12 files changed, 436 insertions(+), 445 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 300a20c..a207831 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -497,7 +497,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
if (binput1.getOp() == OpOp2.POW
&& binput1.getInput().get(1) instanceof LiteralOp) {
LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
- ret = lit.getLongValue() == 3;
+ ret = HopRewriteUtils.getIntValueSafe(lit) == 3;
}
else if (binput1.getOp() == OpOp2.MULT
// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
@@ -640,15 +640,10 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
boolean handled = false;
if (input1.getOp() == OpOp2.POW) {
- switch ((int)((LiteralOp)input12).getLongValue()) {
- case 3:
- in1 = input11.constructLops();
- in2 = in1;
- in3 = in1;
- break;
- default:
- throw new AssertionError("unreachable; only applies to power 3");
- }
+ assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
+ in1 = input11.constructLops();
+ in2 = in1;
+ in3 = in1;
handled = true;
} else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
@@ -662,8 +657,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
case POW: // A*A*B case
Hop b112 = b11.getInput().get(1);
if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
- && b112 instanceof LiteralOp
- && ((LiteralOp)b112).getLongValue() == 2) {
+ && HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b11.getInput().get(0).constructLops();
in2 = in1;
in3 = input12.constructLops();
@@ -682,8 +676,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
break;
case POW: // A*B*B case
Hop b112 = b12.getInput().get(1);
- if ( b112 instanceof LiteralOp
- && ((LiteralOp)b112).getLongValue() == 2) {
+ if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b12.getInput().get(0).constructLops();
in2 = in1;
in3 = input11.constructLops();
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 2a76d07..79b7ee6 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -111,11 +111,6 @@ public class OptimizerUtils
public static boolean ALLOW_CONSTANT_FOLDING = true;
public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
- /**
- * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
- * `A*B*A ==> (A^2)*B`.
- */
- public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
public static boolean ALLOW_OPERATOR_FUSION = true;
/**
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 17ac4ec..8f71359 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -251,8 +251,7 @@ public class HopRewriteUtils
* @return replacement
*/
public static Hop replaceHop(final Hop old, final Hop replacement) {
- final ArrayList<Hop> rootParents = old.getParent();
- if (rootParents.isEmpty())
+ if (old.getParent().isEmpty())
return replacement; // new old!
HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
return replacement;
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index b6aab38..1053850 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
- if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
- _dagRuleSet.add( new RewriteEMult() ); //dependency: cse
+ if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 166af2f..91c5972 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
+import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+
/**
* Rule: Algebraic Simplifications. Simplifies binary expressions
* in terms of two major purposes: (1) rewrite binary operations
@@ -2051,12 +2053,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
- && !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B
- && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it
- && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
- && !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2)
- && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it
- && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+ && ( !ALLOW_SUM_PRODUCT_REWRITES
+ || !( HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B
+ && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it
+ && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 ))
+ && ( !ALLOW_SUM_PRODUCT_REWRITES
+ || !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2)
+ && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it
+ && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 ))
)
{
baLeft = hi2.getInput().get(0);
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
deleted file mode 100644
index 5cd1471..0000000
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ /dev/null
@@ -1,284 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.rewrite;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.HopsException;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.parser.Expression;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
-/**
- * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
- *
- * Rewrite a chain of element-wise multiply hops that contain identical elements.
- * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
- *
- * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
- * since foreign parents may rely on the individual results.
- */
-public class RewriteEMult extends HopRewriteRule {
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
- if( roots == null )
- return null;
- for( int i=0; i<roots.size(); i++ ) {
- Hop h = roots.get(i);
- roots.set(i, rule_RewriteEMult(h));
- }
- return roots;
- }
-
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
- if( root == null )
- return null;
- return rule_RewriteEMult(root);
- }
-
- private static boolean isBinaryMult(final Hop hop) {
- return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
- }
-
- private static Hop rule_RewriteEMult(final Hop root) {
- if (root.isVisited())
- return root;
- root.setVisited();
-
- // 1. Find immediate subtree of EMults.
- if (isBinaryMult(root)) {
- final Hop left = root.getInput().get(0), right = root.getInput().get(1);
- final Set<BinaryOp> emults = new HashSet<>();
- final Multiset<Hop> leaves = HashMultiset.create();
- findEMultsAndLeaves((BinaryOp)root, emults, leaves);
-
- // 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(emults, leaves)) {
- // 3. Check for foreign parents.
- // A foreign parent is a parent of some EMult that is not in the set.
- // Foreign parents destroy correctness of this rewrite.
- final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
- (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
- if (okay) {
- // 4. Construct replacement EMults for the leaves
- final Hop replacement = constructReplacement(leaves);
- if (LOG.isDebugEnabled())
- LOG.debug(String.format(
- "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
- emults.size(), root.getHopID(), replacement.getHopID()));
-
- // 5. Replace root with replacement
- final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
-
- // 6. Recurse at leaves (no need to repeat the interior emults)
- for (final Hop leaf : leaves.elementSet()) {
- recurseInputs(leaf);
- }
- return newRoot;
- }
- }
- }
- // This rewrite is not applicable to the current root.
- // Try the root's children.
- recurseInputs(root);
- return root;
- }
-
- private static void recurseInputs(final Hop parent) {
- final ArrayList<Hop> inputs = parent.getInput();
- for (int i = 0; i < inputs.size(); i++) {
- final Hop input = inputs.get(i);
- final Hop newInput = rule_RewriteEMult(input);
- inputs.set(i, newInput);
- }
- }
-
- private static Hop constructReplacement(final Multiset<Hop> leaves) {
- // Sort by data type
- final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
- for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
- final Hop h = entry.getElement();
- // unlink parents (the EMults, which we are throwing away); NOTE(review): clear() also drops any foreign (non-chain) parents of the leaf
- h.getParent().clear();
- sorted.put(h, entry.getCount());
- }
- // sorted contains all leaves, sorted by data type, stripped from their parents
-
- // Construct right-deep EMult tree
- final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
- Hop first = constructPower(iterator.next());
-
- for (int i = 1; i < sorted.size(); i++) {
- final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
- first.setVisited();
- }
- return first;
- }
-
- private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
- final Hop hop = entry.getKey();
- final int cnt = entry.getValue();
- assert(cnt >= 1);
- if (cnt == 1)
- return hop; // don't set this visited... we will visit this next
- Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
- pow.setVisited();
- return pow;
- }
-
- /**
- * A Comparator that orders Hops by their data type, dimension, and sparsity.
- * The order is as follows:
- * scalars > row vectors > col vectors >
- * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
- * other data types.
- * Disambiguate by Hop ID.
- */
- private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
- @Override
- public final int compare(Hop o1, Hop o2) {
- int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
- if (c != 0) return c;
-
- // o1 and o2 have the same data type
- switch (o1.getDataType()) {
- case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
- case MATRIX:
- // two matrices; check for vectors
- if (o1.getDim1() == 1) { // row vector
- if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
- return compareBySparsityThenId(o1, o2); // both row vectors
- } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
- return -1; // row vectors are the greatest matrices
- } else if (o1.getDim2() == 1) { // col vector
- if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
- return compareBySparsityThenId(o1, o2); // both col vectors
- } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
- return 1; // NOTE(review): o2 is the col vector here, so this should be -1 (comparator is asymmetric)
- } else { // both non-vectors
- return compareBySparsityThenId(o1, o2);
- }
- default:
- return Long.compare(o1.getHopID(), o2.getHopID());
- }
- }
- private int compareBySparsityThenId(Hop o1, Hop o2) {
- // the hop with more nnz is first; unknown nnz (-1) last
- int c = Long.compare(o1.getNnz(), o2.getNnz());
- if (c != 0) return c;
- return Long.compare(o1.getHopID(), o2.getHopID());
- }
- private final int[] orderDataType;
- {
- Expression.DataType[] dtValues = Expression.DataType.values();
- orderDataType = new int[dtValues.length];
- for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
- switch(dtValues[i]) {
- case SCALAR:
- orderDataType[i] = 4;
- break;
- case MATRIX:
- orderDataType[i] = 3;
- break;
- case FRAME:
- orderDataType[i] = 2;
- break;
- case OBJECT:
- orderDataType[i] = 1;
- break;
- case UNKNOWN:
- orderDataType[i] = 0;
- break;
- }
- }
- }
- };
-
- /**
- * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param child An interior emult hop in the emult chain dag.
- * @return true if neither this interior emult nor any descendant emult has a foreign parent (false otherwise).
- */
- private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
- final ArrayList<Hop> parents = child.getParent();
- if (parents.size() > 1)
- for (final Hop parent : parents)
- //noinspection SuspiciousMethodCalls
- if (!emults.contains(parent))
- return false;
- // child does not have foreign parents
-
- final ArrayList<Hop> inputs = child.getInput();
- final Hop left = inputs.get(0), right = inputs.get(1);
- return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
- (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
- }
-
- /**
- * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
- * @param root Root of sub-dag
- * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
- * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
- */
- private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
- // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
- emults.add(root);
-
- final ArrayList<Hop> inputs = root.getInput();
- final Hop left = inputs.get(0), right = inputs.get(1);
-
- if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
- else leaves.add(left);
-
- if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
- else leaves.add(right);
- }
-
- /**
- * Only optimize a subtree of emults if there are at least two emults.
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param leaves The multiset of multiplicands in the emult chain.
- * @return If the multiset is worth optimizing.
- */
- private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
- // Old criterion: there should be at least one repeated leaf
-// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
-// if (hopEntry.getCount() > 1)
-// return true;
-// }
-// return false;
- return emults.size() >= 2;
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
new file mode 100644
index 0000000..bd873ff
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results. Foreign parents of the chain's
+ * leaves are allowed; their links to the leaves are preserved by the rewrite.
+ */
+public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+		if( roots == null )
+			return null;
+		for( int i=0; i<roots.size(); i++ ) {
+			Hop h = roots.get(i);
+			roots.set(i, rule_RewriteEMult(h));
+		}
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+		if( root == null )
+			return null;
+		return rule_RewriteEMult(root);
+	}
+
+	/** @return true iff {@code hop} is a BinaryOp element-wise multiply (MULT). */
+	private static boolean isBinaryMult(final Hop hop) {
+		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+	}
+
+	/**
+	 * Rewrite the e-mult chain rooted at {@code root} when it is profitable and safe,
+	 * then continue the traversal below it. Already-visited hops are returned unchanged.
+	 * @param root The hop to process.
+	 * @return The hop that should take {@code root}'s place in its parent's input list.
+	 */
+	private static Hop rule_RewriteEMult(final Hop root) {
+		if (root.isVisited())
+			return root;
+		root.setVisited();
+
+		// 1. Find immediate subtree of EMults.
+		if (isBinaryMult(root)) {
+			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+			final Set<BinaryOp> emults = new HashSet<>();
+			final Multiset<Hop> leaves = HashMultiset.create();
+			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
+
+			// 2. Ensure it is profitable to do a rewrite.
+			if (isOptimizable(emults, leaves)) {
+				// 3. Check for foreign parents.
+				// A foreign parent is a parent of some EMult that is not in the set.
+				// Foreign parents destroy correctness of this rewrite.
+				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+				if (okay) {
+					// 4. Construct replacement EMults for the leaves
+					final Hop replacement = constructReplacement(emults, leaves);
+					if (LOG.isDebugEnabled())
+						LOG.debug(String.format(
+								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+								emults.size(), root.getHopID(), replacement.getHopID()));
+
+					// 5. Replace root with replacement
+					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+					// 6. Recurse at leaves (no need to repeat the interior emults)
+					for (final Hop leaf : leaves.elementSet()) {
+						recurseInputs(leaf);
+					}
+					return newRoot;
+				}
+			}
+		}
+		// This rewrite is not applicable to the current root.
+		// Try the root's children.
+		recurseInputs(root);
+		return root;
+	}
+
+	/** Apply the rewrite to every input of {@code parent}, storing back any replaced inputs. */
+	private static void recurseInputs(final Hop parent) {
+		final ArrayList<Hop> inputs = parent.getInput();
+		for (int i = 0; i < inputs.size(); i++) {
+			final Hop input = inputs.get(i);
+			final Hop newInput = rule_RewriteEMult(input);
+			inputs.set(i, newInput);
+		}
+	}
+
+	/**
+	 * Build a right-deep e-wise multiply tree over the distinct leaves, raising repeated leaves
+	 * to a power. Each leaf is unlinked only from the e-mult hops of the chain (which are being
+	 * discarded); any other (foreign) parents keep their link, since they still consume the leaf.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The non-empty multiset of multiplicands in the emult chain.
+	 * @return The root of the replacement sub-dag.
+	 */
+	private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Sort by data type
+		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+			final Hop h = entry.getElement();
+			// unlink only the chain's EMults (which we are throwing away); clearing every parent
+			// would also sever foreign parents that still reference this leaf
+			final Iterator<Hop> parentIter = h.getParent().iterator();
+			while (parentIter.hasNext()) {
+				//noinspection SuspiciousMethodCalls (emults holds BinaryOps; parents are typed as Hop)
+				if (emults.contains(parentIter.next()))
+					parentIter.remove();
+			}
+			sorted.put(h, entry.getCount());
+		}
+		// sorted contains all leaves, sorted by data type, detached from the chain's EMults
+
+		// Construct right-deep EMult tree; later (greater) entries become left operands
+		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		Hop first = constructPower(iterator.next());
+		while (iterator.hasNext()) {
+			final Hop second = constructPower(iterator.next());
+			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+			first.setVisited();
+		}
+		return first;
+	}
+
+	/**
+	 * Return the leaf itself when it occurs once, otherwise a new (visited) {@code leaf ^ count} hop.
+	 * @param entry A leaf and its multiplicity (at least 1).
+	 * @return The hop representing the leaf's contribution to the product.
+	 */
+	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+		final Hop hop = entry.getKey();
+		final int cnt = entry.getValue();
+		assert(cnt >= 1);
+		if (cnt == 1)
+			return hop; // don't set this visited... we will visit this next
+		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		pow.setVisited();
+		return pow;
+	}
+
+	/**
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
+	 * Greater hops become left operands in the product built by constructReplacement.
+	 * The order is as follows:
+	 * scalars > row vectors > col vectors >
+	 * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 * other data types.
+	 * Disambiguate by Hop ID.
+	 */
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(Hop o1, Hop o2) {
+			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+			case MATRIX:
+				// two matrices; check for vectors
+				if (o1.getDim1() == 1) { // row vector
+					if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+					return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+					return -1; // row vectors are the greatest matrices
+				} else if (o1.getDim2() == 1) { // col vector
+					if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+					return compareBySparsityThenId(o1, o2); // both col vectors
+				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+					return -1; // col vectors greater than non-vectors (mirror of the case above)
+				} else { // both non-vectors
+					return compareBySparsityThenId(o1, o2);
+				}
+			default:
+				return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+		private int compareBySparsityThenId(Hop o1, Hop o2) {
+			// the hop with more nnz is first; unknown nnz (-1) last
+			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+		private final int[] orderDataType;
+		{
+			Expression.DataType[] dtValues = Expression.DataType.values();
+			orderDataType = new int[dtValues.length];
+			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+				switch(dtValues[i]) {
+				case SCALAR:
+					orderDataType[i] = 4;
+					break;
+				case MATRIX:
+					orderDataType[i] = 3;
+					break;
+				case FRAME:
+					orderDataType[i] = 2;
+					break;
+				case OBJECT:
+					orderDataType[i] = 1;
+					break;
+				case UNKNOWN:
+					orderDataType[i] = 0;
+					break;
+				}
+			}
+		}
+	};
+
+	/**
+	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param child An interior emult hop in the emult chain dag.
+	 * @return true if neither this interior emult nor any descendant emult has a foreign parent
+	 *         (i.e., the rewrite is safe); false as soon as a foreign parent is found.
+	 */
+	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+		final ArrayList<Hop> parents = child.getParent();
+		if (parents.size() > 1)
+			for (final Hop parent : parents)
+				//noinspection SuspiciousMethodCalls (emults holds BinaryOps; parents are typed as Hop)
+				if (!emults.contains(parent))
+					return false;
+		// child does not have foreign parents
+
+		final ArrayList<Hop> inputs = child.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+		return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+	}
+
+	/**
+	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
+	 * @param root Root of sub-dag
+	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
+	 */
+	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+		emults.add(root);
+
+		final ArrayList<Hop> inputs = root.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+
+		if (isBinaryMult(left))
+			findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+		else
+			leaves.add(left);
+
+		if (isBinaryMult(right))
+			findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+		else
+			leaves.add(right);
+	}
+
+	/**
+	 * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The multiset of multiplicands in the emult chain.
+	 * @return If the multiset is worth optimizing.
+	 */
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		return emults.size() >= 2;
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
deleted file mode 100644
index 85dbea4..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Tests whether the e-wise multiply chain `A*B*A` is rewritten to `(A^2)*B`, checking results against R and, with rewrites on, a `^2` heavy hitter.
- */
-public class RewriteEMultChainTest extends AutomatedTestBase
-{
- private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
- private static final String TEST_DIR = "functions/misc/";
- private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
-
- private static final int rows = 123;
- private static final int cols = 321;
- private static final double eps = Math.pow(10, -10);
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
- }
-
- private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
- {
- RUNTIME_PLATFORM platformOld = rtplatform;
- switch( et ){
- case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
- case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
- default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == RUNTIME_PLATFORM.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
- boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
-
- try
- {
- TestConfiguration config = getTestConfiguration(testname);
- loadTestConfiguration(config);
-
- String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
- fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
-
- double Xsparsity = 0.8, Ysparsity = 0.6;
- double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
- writeInputMatrixWithMTD("X", X, true);
- writeInputMatrixWithMTD("Y", Y, true);
-
- //execute tests
- runTest(true, false, null, -1);
- runRScript(true);
-
- //compare matrices
- HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
- HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
- TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-
- //check for presence of power operator, if we did a rewrite
- if( rewrites ) {
- Assert.assertTrue(heavyHittersContainsSubString("^2"));
- }
- }
- finally {
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
new file mode 100644
index 0000000..47b2f0e
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Integration test for the element-wise multiply chain rewrite: `A*B*A` should become `(A^2)*B`.
+ * The same script is run with sum-product rewrites disabled and enabled, on CP and Spark.
+ * DML output must match the R reference, and with rewrites on a `^2` op must show up in the
+ * heavy hitters.
+ */
+public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
+
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	/**
+	 * Run one configuration of the test, restoring all global flags afterwards.
+	 * @param testname Name of the DML/R script pair to execute.
+	 * @param rewrites Value for OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES during the run.
+	 * @param et Execution type that selects the runtime platform.
+	 */
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{
+		// remember global state so it can be restored in the finally block
+		final RUNTIME_PLATFORM savedPlatform = rtplatform;
+		if( et == ExecType.MR )
+			rtplatform = RUNTIME_PLATFORM.HADOOP;
+		else if( et == ExecType.SPARK )
+			rtplatform = RUNTIME_PLATFORM.SPARK;
+		else
+			rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+
+		final boolean savedSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		final boolean savedRewrites = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(testname));
+
+			final String scriptHome = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = scriptHome + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+			fullRScriptName = scriptHome + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			// two random inputs of identical shape with different sparsity and seeds
+			final double sparsityX = 0.8, sparsityY = 0.6;
+			final double[][] matrixX = getRandomMatrix(rows, cols, -1, 1, sparsityX, 7);
+			final double[][] matrixY = getRandomMatrix(rows, cols, -1, 1, sparsityY, 3);
+			writeInputMatrixWithMTD("X", matrixX, true);
+			writeInputMatrixWithMTD("Y", matrixY, true);
+
+			// run the DML script, then the R reference script
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			// DML and R results must agree within eps
+			final HashMap<CellIndex, Double> dmlResult = readDMLMatrixFromHDFS("R");
+			final HashMap<CellIndex, Double> rResult = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");
+
+			// a rewritten chain must contain the power operator
+			if( rewrites )
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = savedRewrites;
+			rtplatform = savedPlatform;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = savedSparkConfig;
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 1829bf0..460829d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -36,7 +37,7 @@ import org.junit.Test;
/**
* Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`.
- * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ * Checks compatibility with {@link RewriteElementwiseMultChainOptimization}.
*/
public class ABATernaryAggregateTest extends AutomatedTestBase
{
@@ -368,14 +369,14 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
- rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+ rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
@@ -411,7 +412,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
}
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e352e6d..deea784 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,6 +50,7 @@ import org.junit.runners.Suite;
ReadAfterWriteTest.class,
RewriteCSETransposeScalarTest.class,
RewriteCTableToRExpandTest.class,
+ RewriteElementwiseMultChainOptimizationChainTest.class,
RewriteEliminateAggregatesTest.class,
RewriteFuseBinaryOpChainTest.class,
RewriteFusedRandTest.class,
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
index 784177d..ee14359 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
@@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
-/** Group together the tests in this package into a single suite so that the Maven build
+/* Group together the tests in this package into a single suite so that the Maven build
* won't run two of them at once. */
@RunWith(Suite.class)
@Suite.SuiteClasses({
+ ABATernaryAggregateTest.class,
CentralMomentWeightsTest.class,
CovarianceWeightsTest.class,
CTableMatrixIgnoreZerosTest.class,
[05/23] systemml git commit: TernaryAggregate now applies to a power of 3.
Posted by mb...@apache.org.
TernaryAggregate now applies to a power of 3.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949
Branch: refs/heads/master
Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d
Parents: 8b832f6
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:06:10 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:24 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 67 ++++++++++++--------
.../functions/misc/RewriteEMultChainTest.java | 7 +-
.../functions/misc/RewriteEMultChainOp.R | 33 ----------
.../functions/misc/RewriteEMultChainOp.dml | 28 --------
.../functions/misc/RewriteEMultChainOpXYX.R | 33 ++++++++++
.../functions/misc/RewriteEMultChainOpXYX.dml | 28 ++++++++
6 files changed, 106 insertions(+), 90 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4573b66..300a20c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
(_direction == Direction.RowCol || _direction == Direction.Col) )
{
Hop input1 = getInput().get(0);
- if( input1.getParent().size() == 1 && //sum single consumer
- input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT
- // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
- && input1.optFindExecType() != ExecType.MR)
- {
- Hop input11 = input1.getInput().get(0);
- Hop input12 = input1.getInput().get(1);
-
- if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) {
- //ternary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
- && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)
- && HopRewriteUtils.isEqualSize(input12, input1);
- }
- else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) {
- //ternary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
- && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)
- && HopRewriteUtils.isEqualSize(input11, input1);
+ if (input1.getParent().size() == 1
+ && input1 instanceof BinaryOp) { //sum single consumer
+ BinaryOp binput1 = (BinaryOp)input1;
+
+ if (binput1.getOp() == OpOp2.POW
+ && binput1.getInput().get(1) instanceof LiteralOp) {
+ LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
+ ret = lit.getLongValue() == 3;
}
- else {
- //binary, arbitrary matrices but no mv/outer operations.
- ret = HopRewriteUtils.isEqualSize(input11, input12);
+ else if (binput1.getOp() == OpOp2.MULT
+ // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
+ && input1.optFindExecType() != ExecType.MR) {
+ Hop input11 = input1.getInput().get(0);
+ Hop input12 = input1.getInput().get(1);
+
+ if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) {
+ //ternary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils
+ .isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils
+ .isEqualSize(input12, input1);
+ } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) {
+ //ternary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils
+ .isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils
+ .isEqualSize(input11, input1);
+ } else {
+ //binary, arbitrary matrices but no mv/outer operations.
+ ret = HopRewriteUtils.isEqualSize(input11, input12);
+ }
}
}
}
@@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
private Lop constructLopsTernaryAggregateRewrite(ExecType et)
throws HopsException, LopsException
{
- Hop input1 = getInput().get(0);
+ BinaryOp input1 = (BinaryOp)getInput().get(0);
Hop input11 = input1.getInput().get(0);
Hop input12 = input1.getInput().get(1);
Lop in1 = null, in2 = null, in3 = null;
boolean handled = false;
-
- if( input11 instanceof BinaryOp ) {
+
+ if (input1.getOp() == OpOp2.POW) {
+ switch ((int)((LiteralOp)input12).getLongValue()) {
+ case 3:
+ in1 = input11.constructLops();
+ in2 = in1;
+ in3 = in1;
+ break;
+ default:
+ throw new AssertionError("unreachable; only applies to power 3");
+ }
+ handled = true;
+ } else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
switch (b11.getOp()) {
case MULT: // A*B*C case
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index 18ed55d..85dbea4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -37,7 +37,7 @@ import org.junit.Test;
*/
public class RewriteEMultChainTest extends AutomatedTestBase
{
- private static final String TEST_NAME1 = "RewriteEMultChainOp";
+ private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
private static final String TEST_DIR = "functions/misc/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
@@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{ "-explain", "hops", "-stats",
- "-args", input("X"), input("Y"), output("R") };
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
@@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
-
+
//execute tests
runTest(true, false, null, -1);
runRScript(true);
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
deleted file mode 100644
index 6d94cc8..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R
+++ /dev/null
@@ -1,33 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-args <- commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-library("matrixStats")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
-Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-
-R = X * Y * X;
-
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
deleted file mode 100644
index 3992403..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
+++ /dev/null
@@ -1,28 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-X = read($1);
-Y = read($2);
-
-R = X * Y * X;
-
-write(R, $3);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file
[06/23] systemml git commit: Add tests for sum(A*A*A).
Posted by mb...@apache.org.
Add tests for sum(A*A*A).
Reduced number of rows and columns in order to decrease test time.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/edbac3b6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/edbac3b6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/edbac3b6
Branch: refs/heads/master
Commit: edbac3b6c4361799da32cef89c8ee4e29e187c9d
Parents: f005d94
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:23:02 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:26 2017 -0700
----------------------------------------------------------------------
.../ternary/ABATernaryAggregateTest.java | 159 ++++++++++++++++++-
.../functions/ternary/AAATernaryAggregateC.R | 31 ++++
.../functions/ternary/AAATernaryAggregateC.dml | 28 ++++
.../functions/ternary/AAATernaryAggregateRC.R | 32 ++++
.../functions/ternary/AAATernaryAggregateRC.dml | 29 ++++
.../functions/ternary/ABATernaryAggregateC.dml | 1 -
6 files changed, 274 insertions(+), 6 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 198e9f4..1829bf0 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -42,13 +42,15 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "ABATernaryAggregateRC";
private final static String TEST_NAME2 = "ABATernaryAggregateC";
+ private final static String TEST_NAME3 = "AAATernaryAggregateRC";
+ private final static String TEST_NAME4 = "AAATernaryAggregateC";
private final static String TEST_DIR = "functions/ternary/";
private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
private final static double eps = 1e-8;
- private final static int rows = 1111;
- private final static int cols = 1011;
+ private final static int rows = 111;
+ private final static int cols = 101;
private final static double sparsity1 = 0.7;
private final static double sparsity2 = 0.3;
@@ -57,7 +59,9 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
}
@Test
@@ -201,6 +205,151 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
}
+
+
+ // another set of tests for the case of sum(A*A*A)
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorMR_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorMR_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixMR_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixMR_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.MR);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixCP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixSP_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.SPARK);
+ }
+
+ //additional tests to check default without rewrites
+
+ @Test
+ public void testTernaryAggregateRCDenseVectorCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseVectorCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCDenseMatrixCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, false, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateRCSparseMatrixCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME3, true, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseVectorCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseVectorCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCDenseMatrixCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, false, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testTernaryAggregateCSparseMatrixCPNoRewrite_AAA() {
+ runTernaryAggregateTest(TEST_NAME4, true, false, false, ExecType.CP);
+ }
@@ -230,7 +379,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-explain","hops","-stats","-args", input("A"), output("R")};
+ programArgs = new String[]{"-explain","-stats","-args", input("A"), output("R")};
fullRScriptName = HOME + testname + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
@@ -254,7 +403,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
//check for rewritten patterns in statistics output
if( rewrites && et != ExecType.MR ) {
String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") +
- (((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*"));
+ (((testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3) || vectors ) ? "tak+*" : "tack+*"));
Assert.assertTrue(Statistics.getCPHeavyHitterOpCodes().contains(opcode));
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
new file mode 100644
index 0000000..a096c2b
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+R = t(as.matrix(colSums(A * A * A)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
new file mode 100644
index 0000000..b576a4d
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+if(1==1){}
+
+R = colSums(A * A * A);
+
+write(R, $2);
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
new file mode 100644
index 0000000..776ddd0
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+s = sum(A * A * A);
+R = as.matrix(s);
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
new file mode 100644
index 0000000..7283703
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+if(1==1){}
+
+s = sum(A * A * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
index 78285af..737b409 100644
--- a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -21,7 +21,6 @@
A = read($1);
B = A * 2;
-C = A * 3;
if(1==1){}
[21/23] systemml git commit: Add new `wumm` pattern to pick up
element-wise multiply rewrite.
Posted by mb...@apache.org.
Add new `wumm` pattern to pick up element-wise multiply rewrite.
The new pattern recognizes a multiplication by the literal 2 (`*2` or `2*`) wrapped around `W*(U%*%t(V))` and rewrites it into a weighted unary matrix-multiply (WUMM) operator.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4
Branch: refs/heads/master
Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed
Parents: e93c487
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:14:48 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:14:48 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/ProgramRewriter.java | 2 +-
.../RewriteAlgebraicSimplificationDynamic.java | 44 +++++++++++++++++++-
2 files changed, 43 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 59565df..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
//internal local debug level
- private static final boolean LDEBUG = false;
+ private static final boolean LDEBUG = false;
private static final boolean CHECK = false;
private ArrayList<HopRewriteRule> _dagRuleSet = null;
http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..6246270 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.OpOp4;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
}
+
+ //Pattern 1.5) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+ if( !appliedPattern
+ && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+ && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+ || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+ {
+ final Hop nl; // non-literal
+ if( hi.getInput().get(0) instanceof LiteralOp ) {
+ nl = hi.getInput().get(1);
+ } else {
+ nl = hi.getInput().get(0);
+ }
+
+ if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+ && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
+ && nl.getDim2() > 1 //not applied for vector-vector mult
+ && nl.getInput().get(0).getDataType() == DataType.MATRIX
+ && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
+ && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+ && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+ && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+ {
+ final Hop W = nl.getInput().get(0);
+ final Hop U = nl.getInput().get(1).getInput().get(0);
+ Hop V = nl.getInput().get(1).getInput().get(1);
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
+ hnew.refreshSizeInformation();
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
+ }
+ }
//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
if( !appliedPattern
[19/23] systemml git commit: Move to dynamic rewrites. Do not rewrite
if top-level dims unknown.
Posted by mb...@apache.org.
Move to dynamic rewrites. Do not rewrite if top-level dims unknown.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d18a4c80
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d18a4c80
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d18a4c80
Branch: refs/heads/master
Commit: d18a4c80dece566ddbad34a7f3c2f70ce544023e
Parents: c4e9228
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:54:31 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:54:31 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/rewrite/ProgramRewriter.java | 6 +++---
.../hops/rewrite/RewriteElementwiseMultChainOptimization.java | 5 +++--
2 files changed, 6 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/d18a4c80/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index a1ff5bc..59565df 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,6 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
- if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
- _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -125,7 +123,9 @@ public class ProgramRewriter
// DYNAMIC REWRITES (which do require size information)
if( dynamicRewrites )
{
- _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
+ _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
+ if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
{
http://git-wip-us.apache.org/repos/asf/systemml/blob/d18a4c80/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index de1def8..1f85bbf 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -44,6 +44,7 @@ import org.apache.sysml.parser.Expression;
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
+ * Does not perform rewrites on an element-wise multiply if its dimensions are unknown.
*
* The new order of element-wise multiply chains is as follows:
* <pre>
@@ -81,8 +82,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return root;
root.setVisited();
- // 1. Find immediate subtree of EMults.
- if (isBinaryMult(root)) {
+ // 1. Find immediate subtree of EMults. Check dimsKnown.
+ if (isBinaryMult(root) && root.dimsKnown()) {
final Hop left = root.getInput().get(0), right = root.getInput().get(1);
final Set<BinaryOp> emults = new HashSet<>();
final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
[18/23] systemml git commit: Get rid of leftover Guava dependency
Posted by mb...@apache.org.
Get rid of leftover Guava dependency
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c4e9228e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c4e9228e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c4e9228e
Branch: refs/heads/master
Commit: c4e9228ed0b86789f8b41a533bf112d681a30318
Parents: 3c4d777
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:46:33 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:46:33 2017 -0700
----------------------------------------------------------------------
...RewriteElementwiseMultChainOptimization.java | 32 +++++++++++---------
1 file changed, 17 insertions(+), 15 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/c4e9228e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 486072b..de1def8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
import java.util.ArrayList;
import java.util.Comparator;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
@@ -34,15 +35,12 @@ import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.parser.Expression;
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
/**
* Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
*
* Rewrite a chain of element-wise multiply hops that contain identical elements.
* For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
@@ -87,7 +85,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if (isBinaryMult(root)) {
final Hop left = root.getInput().get(0), right = root.getInput().get(1);
final Set<BinaryOp> emults = new HashSet<>();
- final Multiset<Hop> leaves = HashMultiset.create();
+ final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
@@ -109,7 +107,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
final Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
// 6. Recurse at leaves (no need to repeat the interior emults)
- for (final Hop leaf : leaves.elementSet()) {
+ for (final Hop leaf : leaves.keySet()) {
recurseInputs(leaf);
}
return newRoot;
@@ -131,15 +129,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
}
- private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Sort by data type
final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
- for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
- final Hop h = entry.getElement();
+ for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
+ final Hop h = entry.getKey();
// unlink parents that are in the emult set(we are throwing them away)
// keep other parents
h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
- sorted.put(h, entry.getCount());
+ sorted.put(h, entry.getValue());
}
// sorted contains all leaves, sorted by data type, stripped from their parents
@@ -244,7 +242,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
/**
- * A Comparator that orders Hops by their data type, dimention, and sparsity.
+ * A Comparator that orders Hops by their data type, dimension, and sparsity.
* The order is as follows:
* scalars < col vectors < row vectors <
* non-vector matrices ordered by sparsity (higher nnz last, unknown sparsity last) >
@@ -324,7 +322,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
* @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
* @param leaves Out parameter. The multiset of multiplicands in the emult chain.
*/
- private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
emults.add(root);
@@ -334,12 +332,16 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if (isBinaryMult(left))
findEMultsAndLeaves((BinaryOp) left, emults, leaves);
else
- leaves.add(left);
+ addMultiset(leaves, left);
if (isBinaryMult(right))
findEMultsAndLeaves((BinaryOp) right, emults, leaves);
else
- leaves.add(right);
+ addMultiset(leaves, right);
+ }
+
+ private static <K> void addMultiset(final Map<K,Integer> map, final K k) {
+ map.put(k, map.getOrDefault(k, 0) + 1);
}
/**
@@ -348,7 +350,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
* @param leaves The multiset of multiplicands in the emult chain.
* @return If the multiset is worth optimizing.
*/
- private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
return emults.size() >= 2;
}
}
[03/23] systemml git commit: Add name to sorting of EMult rewrite.
Handle Ternary A*A*B case.
Posted by mb...@apache.org.
Add name to sorting of EMult rewrite. Handle Ternary A*A*B case.
When one multiplicand is `A^2` (power with literal exponent 2), AggUnaryOp now constructs the ternary aggregate with inputs (A,A,B) instead of (A^2,B,1).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ff8c836c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ff8c836c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ff8c836c
Branch: refs/heads/master
Commit: ff8c836c7b736dbd7b7651ac792a6d8c23989c98
Parents: eb0599d
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 13:18:19 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:18 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 62 ++++++++++++++------
.../apache/sysml/hops/rewrite/RewriteEMult.java | 4 +-
2 files changed, 48 insertions(+), 18 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index ee4ded2..4573b66 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -516,7 +516,6 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
}
}
}
-
return ret;
}
@@ -631,24 +630,53 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
Hop input11 = input1.getInput().get(0);
Hop input12 = input1.getInput().get(1);
- Lop in1 = null;
- Lop in2 = null;
- Lop in3 = null;
+ Lop in1 = null, in2 = null, in3 = null;
+ boolean handled = false;
- if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT )
- {
- in1 = input11.getInput().get(0).constructLops();
- in2 = input11.getInput().get(1).constructLops();
- in3 = input12.constructLops();
- }
- else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT )
- {
- in1 = input11.constructLops();
- in2 = input12.getInput().get(0).constructLops();
- in3 = input12.getInput().get(1).constructLops();
+ if( input11 instanceof BinaryOp ) {
+ BinaryOp b11 = (BinaryOp)input11;
+ switch (b11.getOp()) {
+ case MULT: // A*B*C case
+ in1 = input11.getInput().get(0).constructLops();
+ in2 = input11.getInput().get(1).constructLops();
+ in3 = input12.constructLops();
+ handled = true;
+ break;
+ case POW: // A*A*B case
+ Hop b112 = b11.getInput().get(1);
+ if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
+ && b112 instanceof LiteralOp
+ && ((LiteralOp)b112).getLongValue() == 2) {
+ in1 = b11.getInput().get(0).constructLops();
+ in2 = in1;
+ in3 = input12.constructLops();
+ handled = true;
+ }
+ break;
+ }
+ } else if( input12 instanceof BinaryOp ) {
+ BinaryOp b12 = (BinaryOp)input12;
+ switch (b12.getOp()) {
+ case MULT: // A*B*C case
+ in1 = input11.constructLops();
+ in2 = input12.getInput().get(0).constructLops();
+ in3 = input12.getInput().get(1).constructLops();
+ handled = true;
+ break;
+ case POW: // A*B*B case
+ Hop b112 = b12.getInput().get(1);
+ if ( b112 instanceof LiteralOp
+ && ((LiteralOp)b112).getLongValue() == 2) {
+ in1 = b12.getInput().get(0).constructLops();
+ in2 = in1;
+ in3 = input11.constructLops();
+ handled = true;
+ }
+ break;
+ }
}
- else
- {
+
+ if (!handled) {
in1 = input11.constructLops();
in2 = input12.constructLops();
in3 = new LiteralOp(1).constructLops();
http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 2c9e5cb..66da6fa 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -145,7 +145,9 @@ public class RewriteEMult extends HopRewriteRule {
return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
}
- private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
+ private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
+ .thenComparing(Hop::getName)
+ .thenComparingInt(Object::hashCode);
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
[23/23] systemml git commit: [SYSTEMML-1663] Fix and enable rewrite
element-wise multiply chains
Posted by mb...@apache.org.
[SYSTEMML-1663] Fix and enable rewrite element-wise multiply chains
Closes #567.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/85e3a963
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/85e3a963
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/85e3a963
Branch: refs/heads/master
Commit: 85e3a9631081e8f44a21e4c7e99bbcf03859cba7
Parents: 1b3dff0 14bd65a
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Jul 14 21:10:18 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Jul 14 21:10:19 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/ProgramRewriter.java | 8 +-
.../RewriteAlgebraicSimplificationDynamic.java | 45 ++++-
...RewriteElementwiseMultChainOptimization.java | 189 +++++++++++++------
...ElementwiseMultChainOptimizationAllTest.java | 134 +++++++++++++
...iteElementwiseMultChainOptimizationTest.java | 4 +-
.../functions/misc/RewriteEMultChainOpAll.R | 37 ++++
.../functions/misc/RewriteEMultChainOpAll.dml | 31 +++
7 files changed, 379 insertions(+), 69 deletions(-)
----------------------------------------------------------------------
[11/23] systemml git commit: Review comments, part 2
Posted by mb...@apache.org.
Review comments, part 2
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d6d37952
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d6d37952
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d6d37952
Branch: refs/heads/master
Commit: d6d37952bdb1cd76d61c376b2051292ca272ee0a
Parents: 737f93b
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 02:27:34 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:45 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 23 +++++-------
...RewriteElementwiseMultChainOptimization.java | 38 +++++++-------------
.../java/org/apache/sysml/utils/Explain.java | 2 +-
3 files changed, 22 insertions(+), 41 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 8f71359..b98901a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,22 +246,15 @@ public class HopRewriteUtils
* Replace an old Hop with a replacement Hop.
* If the old Hop has no parents, then return the replacement.
* Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
- * @param old To be replaced
- * @param replacement The replacement
- * @return replacement
+ * @param hold To be replaced
+ * @param hnew The replacement
+ * @return hnew
*/
- public static Hop replaceHop(final Hop old, final Hop replacement) {
- if (old.getParent().isEmpty())
- return replacement; // new old!
- HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
- return replacement;
- }
-
-
- public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) {
- ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent());
- for( Hop lparent : parents )
- HopRewriteUtils.replaceChildReference(lparent, hold, hnew);
+ public static Hop rewireAllParentChildReferences( Hop hold, Hop hnew ) {
+ ArrayList<Hop> parents = hold.getParent();
+ while (!parents.isEmpty())
+ HopRewriteUtils.replaceChildReference(parents.get(0), hold, hnew);
+ return hnew;
}
public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1dd5813..91b7306 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -98,7 +98,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
emults.size(), root.getHopID(), replacement.getHopID()));
// 5. Replace root with replacement
- final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+ final Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
// 6. Recurse at leaves (no need to repeat the interior emults)
for (final Hop leaf : leaves.elementSet()) {
@@ -166,6 +166,18 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
* Disambiguate by Hop ID.
*/
private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+ private final int[] orderDataType = new int[Expression.DataType.values().length];
+ {
+ for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
+ switch(Expression.DataType.values()[i]) {
+ case SCALAR: orderDataType[i] = 4; break;
+ case MATRIX: orderDataType[i] = 3; break;
+ case FRAME: orderDataType[i] = 2; break;
+ case OBJECT: orderDataType[i] = 1; break;
+ case UNKNOWN:orderDataType[i] = 0; break;
+ }
+ }
+
@Override
public final int compare(Hop o1, Hop o2) {
int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
@@ -198,30 +210,6 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
if (c != 0) return c;
return Long.compare(o1.getHopID(), o2.getHopID());
}
- private final int[] orderDataType;
- {
- Expression.DataType[] dtValues = Expression.DataType.values();
- orderDataType = new int[dtValues.length];
- for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
- switch(dtValues[i]) {
- case SCALAR:
- orderDataType[i] = 4;
- break;
- case MATRIX:
- orderDataType[i] = 3;
- break;
- case FRAME:
- orderDataType[i] = 2;
- break;
- case OBJECT:
- orderDataType[i] = 1;
- break;
- case UNKNOWN:
- orderDataType[i] = 0;
- break;
- }
- }
- }
}.reversed();
/**
http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 450c6e5..6451396 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -76,7 +76,7 @@ public class Explain
//internal configuration parameters
private static final boolean REPLACE_SPECIAL_CHARACTERS = true;
private static final boolean SHOW_MEM_ABOVE_BUDGET = true;
- private static final boolean SHOW_LITERAL_HOPS = true;
+ private static final boolean SHOW_LITERAL_HOPS = false;
private static final boolean SHOW_DATA_DEPENDENCIES = true;
private static final boolean SHOW_DATA_FLOW_PROPERTIES = true;
[08/23] systemml git commit: Document RewriteEMult. Add smart
recursion.
Posted by mb...@apache.org.
Document RewriteEMult. Add smart recursion.
RewriteEMult now rewrites emult chains deeper than the top-most one.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f
Branch: refs/heads/master
Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e
Parents: a5846bb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sat Jun 10 01:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:33 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/HopRewriteUtils.java | 2 +
.../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++-------
.../org/apache/sysml/parser/Expression.java | 1 -
3 files changed, 62 insertions(+), 31 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 4d23cb9..17ac4ec 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,6 +246,8 @@ public class HopRewriteUtils
* Replace an old Hop with a replacement Hop.
* If the old Hop has no parents, then return the replacement.
* Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
+ * @param old To be replaced
+ * @param replacement The replacement
* @return replacement
*/
public static Hop replaceHop(final Hop old, final Hop replacement) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index d483a08..5cd1471 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Multiset;
*
* Rewrite a chain of element-wise multiply hops that contain identical elements.
* For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
@@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule {
return root;
root.setVisited();
- final ArrayList<Hop> rootInputs = root.getInput();
-
// 1. Find immediate subtree of EMults.
if (isBinaryMult(root)) {
- final Hop left = rootInputs.get(0), right = rootInputs.get(1);
- final BinaryOp r = (BinaryOp)root;
+ final Hop left = root.getInput().get(0), right = root.getInput().get(1);
final Set<BinaryOp> emults = new HashSet<>();
final Multiset<Hop> leaves = HashMultiset.create();
- findEMultsAndLeaves(r, emults, leaves);
+ findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(leaves)) {
+ if (isOptimizable(emults, leaves)) {
// 3. Check for foreign parents.
// A foreign parent is a parent of some EMult that is not in the set.
// Foreign parents destroy correctness of this rewrite.
@@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule {
if (okay) {
// 4. Construct replacement EMults for the leaves
final Hop replacement = constructReplacement(leaves);
- // 5. Replace root with replacement
if (LOG.isDebugEnabled())
LOG.debug(String.format(
"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
emults.size(), root.getHopID(), replacement.getHopID()));
- replacement.setVisited();
- return HopRewriteUtils.replaceHop(root, replacement);
+
+ // 5. Replace root with replacement
+ final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+ // 6. Recurse at leaves (no need to repeat the interior emults)
+ for (final Hop leaf : leaves.elementSet()) {
+ recurseInputs(leaf);
+ }
+ return newRoot;
}
}
}
-
// This rewrite is not applicable to the current root.
// Try the root's children.
- for (int i = 0; i < rootInputs.size(); i++) {
- final Hop input = rootInputs.get(i);
+ recurseInputs(root);
+ return root;
+ }
+
+ private static void recurseInputs(final Hop parent) {
+ final ArrayList<Hop> inputs = parent.getInput();
+ for (int i = 0; i < inputs.size(); i++) {
+ final Hop input = inputs.get(i);
final Hop newInput = rule_RewriteEMult(input);
- rootInputs.set(i, newInput);
+ inputs.set(i, newInput);
}
- return root;
}
private static Hop constructReplacement(final Multiset<Hop> leaves) {
@@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule {
for (int i = 1; i < sorted.size(); i++) {
final Hop second = constructPower(iterator.next());
first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+ first.setVisited();
}
return first;
}
@@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule {
final Hop hop = entry.getKey();
final int cnt = entry.getValue();
assert(cnt >= 1);
- if (cnt == 1) return hop;
- return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ if (cnt == 1)
+ return hop; // don't set this visited... we will visit this next
+ Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ pow.setVisited();
+ return pow;
}
-
-
- // Order: scalars > row vectors > col vectors >
- // non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
- // other data types
- // disambiguate by Hop ID
+ /**
+ * A Comparator that orders Hops by their data type, dimension, and sparsity.
+ * The order is as follows:
+ * scalars > row vectors > col vectors >
+ * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ * other data types.
+ * Disambiguate by Hop ID.
+ */
private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
@Override
public final int compare(Hop o1, Hop o2) {
@@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule {
}
};
+ /**
+ * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param child An interior emult hop in the emult chain dag.
+ * @return Whether this interior emult or any child emult has a foreign parent.
+ */
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
if (parents.size() > 1)
@@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule {
}
/**
- * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root.
+ * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
+ * @param root Root of sub-dag
+ * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+ * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
*/
private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
@@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule {
else leaves.add(right);
}
- /** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
- private static boolean isOptimizable(final Multiset<Hop> set) {
- for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
- if (hopEntry.getCount() > 1)
- return true;
- }
- return false;
+ /**
+ * Only optimize a subtree of emults if there are at least two emults.
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param leaves The multiset of multiplicands in the emult chain.
+ * @return If the multiset is worth optimizing.
+ */
+ private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ // Old criterion: there should be at least one repeated leaf
+// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
+// if (hopEntry.getCount() > 1)
+// return true;
+// }
+// return false;
+ return emults.size() >= 2;
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index b944e29..9ee3fba 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,7 +162,6 @@ public abstract class Expression
* Data types (matrix, scalar, frame, object, unknown).
*/
public enum DataType {
- // Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType
MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
public boolean isMatrix() {
[07/23] systemml git commit: simplifyDotProductSum shall not
interfere with tak+*
Posted by mb...@apache.org.
simplifyDotProductSum shall not interfere with tak+*
Added conditions to the dynamic algebraic rewrite simplifyDotProductSum
that do not apply the optimization for (A^2)*B or B*(A^2),
since TernaryAggregate handles these.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5846bbb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5846bbb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5846bbb
Branch: refs/heads/master
Commit: a5846bbb383c655189963bffefed1c0db4ffcc89
Parents: edbac3b
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 23:58:16 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:30 2017 -0700
----------------------------------------------------------------------
.../hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/a5846bbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index ad80c05..166af2f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2050,7 +2050,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
- && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) )
+ && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
+ && !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B
+ && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it
+ && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
+ && !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2)
+ && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it
+ && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+ )
{
baLeft = hi2.getInput().get(0);
baRight = hi2.getInput().get(1);
[14/23] systemml git commit: Review comments 3
Posted by mb...@apache.org.
Review comments 3
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04f692df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04f692df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04f692df
Branch: refs/heads/master
Commit: 04f692dfcb25a032044dabb7064241073f959300
Parents: de469d2
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 18 16:54:51 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:54 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 4 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 2 +-
...ementwiseMultChainOptimizationChainTest.java | 127 -------------------
...iteElementwiseMultChainOptimizationTest.java | 127 +++++++++++++++++++
.../functions/misc/ZPackageSuite.java | 2 +-
5 files changed, 132 insertions(+), 130 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index a207831..8e681c1 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -647,7 +647,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
} else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
- switch (b11.getOp()) {
+ switch( b11.getOp() ) {
case MULT: // A*B*C case
in1 = input11.getInput().get(0).constructLops();
in2 = input11.getInput().get(1).constructLops();
@@ -664,6 +664,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
}
break;
+ default: break;
}
} else if( input12 instanceof BinaryOp ) {
BinaryOp b12 = (BinaryOp)input12;
@@ -683,6 +684,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
}
break;
+ default: break;
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 1053850..7ee3ccb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -97,7 +97,7 @@ public class ProgramRewriter
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
- _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
deleted file mode 100644
index e490750..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
- */
-public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
-{
- private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
- private static final String TEST_DIR = "functions/misc/";
- private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
-
- private static final int rows = 123;
- private static final int cols = 321;
- private static final double eps = Math.pow(10, -10);
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
- }
-
- private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
- {
- RUNTIME_PLATFORM platformOld = rtplatform;
- switch( et ){
- case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
- case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
- default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == RUNTIME_PLATFORM.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
- boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-
- try
- {
- TestConfiguration config = getTestConfiguration(testname);
- loadTestConfiguration(config);
-
- String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
- fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
-
- double Xsparsity = 0.8, Ysparsity = 0.6;
- double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
- writeInputMatrixWithMTD("X", X, true);
- writeInputMatrixWithMTD("Y", Y, true);
-
- //execute tests
- runTest(true, false, null, -1);
- runRScript(true);
-
- //compare matrices
- HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
- HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
- TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-
- //check for presence of power operator, if we did a rewrite
- if( rewrites ) {
- Assert.assertTrue(heavyHittersContainsSubString("^2"));
- }
- }
- finally {
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
new file mode 100644
index 0000000..91cb4e0
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
+ */
+public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationTest.class.getSimpleName() + "/";
+
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check for presence of power operator, if we did a rewrite
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally {
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index deea784..860cdbe 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,7 +50,7 @@ import org.junit.runners.Suite;
ReadAfterWriteTest.class,
RewriteCSETransposeScalarTest.class,
RewriteCTableToRExpandTest.class,
- RewriteElementwiseMultChainOptimizationChainTest.class,
+ RewriteElementwiseMultChainOptimizationTest.class,
RewriteEliminateAggregatesTest.class,
RewriteFuseBinaryOpChainTest.class,
RewriteFusedRandTest.class,
[10/23] systemml git commit: Add scalars to Rewrite Emult test
Posted by mb...@apache.org.
Add scalars to Rewrite Emult test
Not sure how to check this in an assert statement
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/737f93b1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/737f93b1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/737f93b1
Branch: refs/heads/master
Commit: 737f93b15a96aba31bc6c6da3651be309e3b8b0c
Parents: b94557f
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 01:56:07 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:41 2017 -0700
----------------------------------------------------------------------
.../hops/rewrite/RewriteElementwiseMultChainOptimization.java | 2 +-
src/main/java/org/apache/sysml/utils/Explain.java | 4 ++--
.../misc/RewriteElementwiseMultChainOptimizationChainTest.java | 4 ++--
.../integration/functions/ternary/ABATernaryAggregateTest.java | 5 +----
src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R | 4 ++--
src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml | 2 +-
6 files changed, 9 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index bd873ff..1dd5813 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -222,7 +222,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
}
}
}
- };
+ }.reversed();
/**
* Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 5cf0548..450c6e5 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -76,7 +76,7 @@ public class Explain
//internal configuration parameters
private static final boolean REPLACE_SPECIAL_CHARACTERS = true;
private static final boolean SHOW_MEM_ABOVE_BUDGET = true;
- private static final boolean SHOW_LITERAL_HOPS = false;
+ private static final boolean SHOW_LITERAL_HOPS = true;
private static final boolean SHOW_DATA_DEPENDENCIES = true;
private static final boolean SHOW_DATA_FLOW_PROPERTIES = true;
@@ -566,7 +566,7 @@ public class Explain
childs.append(" (");
boolean childAdded = false;
for( Hop input : hop.getInput() )
- if( !(input instanceof LiteralOp) ){
+ if( SHOW_LITERAL_HOPS || !(input instanceof LiteralOp) ){
childs.append(childAdded?",":"");
childs.append(input.getHopID());
childAdded = true;
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
index 47b2f0e..e490750 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -33,7 +33,7 @@ import org.junit.Assert;
import org.junit.Test;
/**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
*/
public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
{
@@ -96,7 +96,7 @@ public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedT
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
+ rCmd = getRCmd(inputDir(), expectedDir());
double Xsparsity = 0.8, Ysparsity = 0.6;
double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 460829d..12525c9 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -368,15 +368,13 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
if( rtplatform == RUNTIME_PLATFORM.SPARK )
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
- rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
@@ -412,7 +410,6 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
}
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
index 6d94cc8..fec61ae 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -28,6 +28,6 @@ library("matrixStats")
X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
index 3992403..88f252f 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -23,6 +23,6 @@
X = read($1);
Y = read($2);
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
write(R, $3);
\ No newline at end of file
[04/23] systemml git commit: Correct ordering of e-mult chain
rewrites.
Posted by mb...@apache.org.
Correct ordering of e-mult chain rewrites.
Sorting scalars, vectors, matrices appropriately and by sparsity (when nnz information is available).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62
Branch: refs/heads/master
Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753
Parents: ff8c836
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 20:48:57 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:21 2017 -0700
----------------------------------------------------------------------
.../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++--
.../functions/misc/RewriteEMultChainTest.java | 7 +-
2 files changed, 74 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 66da6fa..d483a08 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
@@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule {
}
// sorted contains all leaves, sorted by data type, stripped from their parents
- // Construct left-deep EMult tree
- Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ // Construct right-deep EMult tree
+ final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
Hop first = constructPower(iterator.next());
for (int i = 1; i < sorted.size(); i++) {
final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+ first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
}
return first;
}
@@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule {
final Hop hop = entry.getKey();
final int cnt = entry.getValue();
assert(cnt >= 1);
- if (cnt == 1)
- return hop;
+ if (cnt == 1) return hop;
return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
}
- private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
- .thenComparing(Hop::getName)
- .thenComparingInt(Object::hashCode);
+
+
+ // Order: scalars > row vectors > col vectors >
+ // non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ // other data types
+ // disambiguate by Hop ID
+ private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() { // ascending; compare() must be antisymmetric for the sorted map
+ @Override
+ public final int compare(Hop o1, Hop o2) {
+ int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]); // rank data types first
+ if (c != 0) return c;
+
+ // o1 and o2 have the same data type; break ties per type
+ switch (o1.getDataType()) {
+ case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
+ case MATRIX:
+ // two matrices; order: row vectors > col vectors > non-vectors
+ if (o1.getDim1() == 1) { // o1 is a row vector
+ if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+ return compareBySparsityThenId(o1, o2); // both row vectors
+ } else if (o2.getDim1() == 1) { // o2 is a row vector; o1 is not
+ return -1; // row vectors are the greatest matrices
+ } else if (o1.getDim2() == 1) { // o1 is a col vector; neither is a row vector
+ if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+ return compareBySparsityThenId(o1, o2); // both col vectors
+ } else if (o2.getDim2() == 1) { // o2 is a col vector; o1 is a non-vector
+ return -1; // o1 < o2; must mirror the 'return 1' above or compare() is not antisymmetric
+ } else { // both non-vectors
+ return compareBySparsityThenId(o1, o2);
+ }
+ default:
+ return Long.compare(o1.getHopID(), o2.getHopID());
+ }
+ }
+ private int compareBySparsityThenId(Hop o1, Hop o2) {
+ // ascending by nnz: more nnz compares greater, unknown nnz (-1) compares least; ties broken by Hop ID
+ int c = Long.compare(o1.getNnz(), o2.getNnz());
+ if (c != 0) return c;
+ return Long.compare(o1.getHopID(), o2.getHopID());
+ }
+ private final int[] orderDataType; // rank per DataType ordinal; higher rank compares greater
+ {
+ Expression.DataType[] dtValues = Expression.DataType.values();
+ orderDataType = new int[dtValues.length];
+ for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+ switch(dtValues[i]) {
+ case SCALAR:
+ orderDataType[i] = 4;
+ break;
+ case MATRIX:
+ orderDataType[i] = 3;
+ break;
+ case FRAME:
+ orderDataType[i] = 2;
+ break;
+ case OBJECT:
+ orderDataType[i] = 1;
+ break;
+ case UNKNOWN:
+ orderDataType[i] = 0;
+ break;
+ }
+ }
+ }
+ };
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index e076c95..18ed55d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
- double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
@@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
- }
+ }
}
[22/23] systemml git commit: Ensure the `wumm` new pattern does not
have foreign parents
Posted by mb...@apache.org.
Ensure the `wumm` new pattern does not have foreign parents
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14bd65a5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14bd65a5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14bd65a5
Branch: refs/heads/master
Commit: 14bd65a551b227860faba56ed1633db85d7110f2
Parents: 479b9da
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:41:51 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:41:51 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/14bd65a5/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 6246270..09b66de 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1960,7 +1960,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
}
- //Pattern 1.5) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+ //Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
if( !appliedPattern
&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
&& (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
@@ -1974,6 +1974,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
}
if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+ && nl.getParent().size()==1 // ensure no foreign parents
&& HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
&& nl.getDim2() > 1 //not applied for vector-vector mult
&& nl.getInput().get(0).getDataType() == DataType.MATRIX
[17/23] systemml git commit: Merge branch 'rewrite-emult' into
rewrite-emult2
Posted by mb...@apache.org.
Merge branch 'rewrite-emult' into rewrite-emult2
# Conflicts:
# src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
# src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
# src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
# src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3c4d777a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3c4d777a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3c4d777a
Branch: refs/heads/master
Commit: 3c4d777a5a15ed59681547e91b84c8812d3420fc
Parents: b67f186 999fdfb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:26:59 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:37:36 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/ProgramRewriter.java | 4 +-
...RewriteElementwiseMultChainOptimization.java | 184 +++++++++++++------
...ElementwiseMultChainOptimizationAllTest.java | 134 ++++++++++++++
...iteElementwiseMultChainOptimizationTest.java | 4 +-
.../functions/misc/RewriteEMultChainOpAll.R | 37 ++++
.../functions/misc/RewriteEMultChainOpAll.dml | 31 ++++
6 files changed, 334 insertions(+), 60 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/3c4d777a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------