You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:30 UTC
[04/23] systemml git commit: Correct ordering of e-mult chain rewrites.
Correct ordering of e-mult chain rewrites.
Sort scalars, vectors, and matrices appropriately, ordering matrices by sparsity when nnz information is available.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62
Branch: refs/heads/master
Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753
Parents: ff8c836
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 20:48:57 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:21 2017 -0700
----------------------------------------------------------------------
.../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++--
.../functions/misc/RewriteEMultChainTest.java | 7 +-
2 files changed, 74 insertions(+), 11 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 66da6fa..d483a08 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
@@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule {
}
// sorted contains all leaves, sorted by data type, stripped from their parents
- // Construct left-deep EMult tree
- Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ // Construct right-deep EMult tree
+ final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
Hop first = constructPower(iterator.next());
for (int i = 1; i < sorted.size(); i++) {
final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+ first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
}
return first;
}
@@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule {
final Hop hop = entry.getKey();
final int cnt = entry.getValue();
assert(cnt >= 1);
- if (cnt == 1)
- return hop;
+ if (cnt == 1) return hop;
return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
}
- private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
- .thenComparing(Hop::getName)
- .thenComparingInt(Object::hashCode);
+
+
+	// Intended for ordering the leaf hops of an e-mult chain.
+	// Order: scalars > row vectors > col vectors >
+	//        non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	//        other data types.
+	// Remaining ties are broken by Hop ID.
+	// The comparator must be antisymmetric (sgn(compare(a, b)) == -sgn(compare(b, a)));
+	// sorted collections such as TreeMap lose or misplace keys otherwise.
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(final Hop o1, final Hop o2) {
+			final int c = Integer.compare(dataTypeRank(o1.getDataType()), dataTypeRank(o2.getDataType()));
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+				case MATRIX:
+					return compareMatrices(o1, o2);
+				default:
+					// scalars and all other data types: order by Hop ID only
+					return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+
+		/**
+		 * Orders two matrix hops: row vectors > col vectors > non-vector matrices.
+		 * Hops in the same shape class are ordered by sparsity, then by Hop ID.
+		 * Comparing integer shape ranks keeps the result antisymmetric by construction,
+		 * unlike a pairwise if/else cascade where each mirrored case must be kept in sync.
+		 */
+		private int compareMatrices(final Hop o1, final Hop o2) {
+			final int c = Integer.compare(shapeRank(o1), shapeRank(o2));
+			if (c != 0) return c;
+			return compareBySparsityThenId(o1, o2);
+		}
+
+		/** Shape class of a matrix hop: 2 = row vector (incl. 1x1), 1 = col vector, 0 = other. */
+		private int shapeRank(final Hop h) {
+			if (h.getDim1() == 1) return 2; // row vector
+			if (h.getDim2() == 1) return 1; // col vector
+			return 0;
+		}
+
+		// More nnz compares greater; unknown nnz (-1) compares least. Ties are broken by Hop ID.
+		private int compareBySparsityThenId(final Hop o1, final Hop o2) {
+			final int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+
+		/**
+		 * Rank of a data type in the order above; a higher rank compares greater.
+		 * Data types not listed (e.g. UNKNOWN) get the lowest rank, 0.
+		 */
+		private int dataTypeRank(final Expression.DataType dt) {
+			switch (dt) {
+				case SCALAR: return 4;
+				case MATRIX: return 3;
+				case FRAME:  return 2;
+				case OBJECT: return 1;
+				default:     return 0;
+			}
+		}
+	};
private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
final ArrayList<Hop> parents = child.getParent();
http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index e076c95..18ed55d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
- double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
@@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
- }
+ }
}