You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 06:27:34 UTC
systemml git commit: [SYSTEMML-1663] Fix and enable rewrite
element-wise multiply chains [Forced Update!]
Repository: systemml
Updated Branches:
refs/heads/master 85e3a9631 -> eca9dbbb8 (forced update)
[SYSTEMML-1663] Fix and enable rewrite element-wise multiply chains
Groups the inputs of element-wise multiply chains together by type.
Adds a comprehensive test that multiplies together all different kinds of objects.
The new order of element-wise multiply chains is as follows:
<pre>
(((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
* ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
* ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
</pre>
Moves the rewrite into the dynamic rewrites (which require size information). Does not rewrite if the top-level dimensions are unknown.
Adds a new 'wumm' (weighted unary matrix multiply) pattern that picks up the output of the element-wise multiply rewrite.
The new pattern recognizes a '*2' or '2*' applied outside 'W*(U%*%t(V))'.
Closes #567.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eca9dbbb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eca9dbbb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eca9dbbb
Branch: refs/heads/master
Commit: eca9dbbb85971af688e81c9254538c53fc429b30
Parents: 1b3dff0
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jul 14 23:08:46 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Jul 14 23:08:46 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/rewrite/ProgramRewriter.java | 8 +-
.../RewriteAlgebraicSimplificationDynamic.java | 45 ++++-
...RewriteElementwiseMultChainOptimization.java | 180 +++++++++++++------
...ElementwiseMultChainOptimizationAllTest.java | 134 ++++++++++++++
.../functions/misc/RewriteEMultChainOpAll.R | 37 ++++
.../functions/misc/RewriteEMultChainOpAll.dml | 31 ++++
6 files changed, 376 insertions(+), 59 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 92d31c2..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
//internal local debug level
- private static final boolean LDEBUG = false;
+ private static final boolean LDEBUG = false;
private static final boolean CHECK = false;
private ArrayList<HopRewriteRule> _dagRuleSet = null;
@@ -96,8 +96,6 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
- //if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
- // _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -125,7 +123,9 @@ public class ProgramRewriter
// DYNAMIC REWRITES (which do require size information)
if( dynamicRewrites )
{
- _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
+ _dagRuleSet.add( new RewriteMatrixMultChainOptimization() ); //dependency: cse
+ if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
{
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..09b66de 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataGenOp;
import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.DataGenMethod;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.OpOp4;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,47 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
appliedPattern = true;
LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");
}
+
+ //Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+ if( !appliedPattern
+ && hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+ && (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+ || HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+ {
+ final Hop nl; // non-literal
+ if( hi.getInput().get(0) instanceof LiteralOp ) {
+ nl = hi.getInput().get(1);
+ } else {
+ nl = hi.getInput().get(0);
+ }
+
+ if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+ && nl.getParent().size()==1 // ensure no foreign parents
+ && HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
+ && nl.getDim2() > 1 //not applied for vector-vector mult
+ && nl.getInput().get(0).getDataType() == DataType.MATRIX
+ && nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
+ && HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+ && (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+ && HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+ {
+ final Hop W = nl.getInput().get(0);
+ final Hop U = nl.getInput().get(1).getInput().get(0);
+ Hop V = nl.getInput().get(1).getInput().get(1);
+ if( !HopRewriteUtils.isTransposeOperation(V) )
+ V = HopRewriteUtils.createTranspose(V);
+ else
+ V = V.getInput().get(0);
+
+ hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+ OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
+ hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
+ hnew.refreshSizeInformation();
+
+ appliedPattern = true;
+ LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
+ }
+ }
//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
if( !appliedPattern
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index c2c3b11..2e411f6 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
@@ -44,6 +44,15 @@ import org.apache.sysml.parser.Expression;
*
* Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
* since foreign parents may rely on the individual results.
+ * Does not perform rewrites on an element-wise multiply if its dimensions are unknown.
+ *
+ * The new order of element-wise multiply chains is as follows:
+ * <pre>
+ * (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
+ * * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
+ * * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
+ * </pre>
+ * Identical elements are replaced with powers.
*/
public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
@Override
@@ -73,15 +82,18 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return root;
root.setVisited();
- // 1. Find immediate subtree of EMults.
- if (isBinaryMult(root)) {
+ // 1. Find immediate subtree of EMults. Check dimsKnown.
+ if (isBinaryMult(root) && root.dimsKnown()) {
final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+ // The set of BinaryOp element-wise multiply hops in the emult chain.
final Set<BinaryOp> emults = new HashSet<>();
+ // The multiset of multiplicands in the emult chain.
final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
findEMultsAndLeaves((BinaryOp)root, emults, leaves);
// 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(emults, leaves)) {
+ // Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+ if (emults.size() >= 2) {
// 3. Check for foreign parents.
// A foreign parent is a parent of some EMult that is not in the set.
// Foreign parents destroy correctness of this rewrite.
@@ -123,38 +135,110 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Sort by data type
- final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+ final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
final Hop h = entry.getKey();
// unlink parents that are in the emult set(we are throwing them away)
// keep other parents
h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
- sorted.put(h, entry.getValue());
+ sorted.add(constructPower(h, entry.getValue()));
}
// sorted contains all leaves, sorted by data type, stripped from their parents
// Construct right-deep EMult tree
- // TODO compile binary outer mult for transition from row and column vectors to matrices
- // TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition
- final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
- Hop first = constructPower(iterator.next());
-
- for (int i = 1; i < sorted.size(); i++) {
- final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
- first.setVisited();
+ final Iterator<Hop> iterator = sorted.iterator();
+
+ Hop next = iterator.hasNext() ? iterator.next() : null;
+ Hop colVectorsScalars = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.SCALAR
+ || next.getDataType() == Expression.DataType.MATRIX && next.getDim2() == 1))
+ {
+ if( colVectorsScalars == null )
+ colVectorsScalars = next;
+ else {
+ colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+ colVectorsScalars.setVisited();
+ }
+ next = iterator.hasNext() ? iterator.next() : null;
+ }
+ // next is not processed and is either null or past col vectors
+
+ Hop rowVectors = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.MATRIX && next.getDim1() == 1))
+ {
+ if( rowVectors == null )
+ rowVectors = next;
+ else {
+ rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+ rowVectors.setVisited();
+ }
+ next = iterator.hasNext() ? iterator.next() : null;
+ }
+ // next is not processed and is either null or past row vectors
+
+ Hop matrices = null;
+ while(next != null &&
+ (next.getDataType() == Expression.DataType.MATRIX))
+ {
+ if( matrices == null )
+ matrices = next;
+ else {
+ matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+ matrices.setVisited();
+ }
+ next = iterator.hasNext() ? iterator.next() : null;
+ }
+ // next is not processed and is either null or past matrices
+
+ Hop other = null;
+ while(next != null)
+ {
+ if( other == null )
+ other = next;
+ else {
+ other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
+ other.setVisited();
+ }
+ next = iterator.hasNext() ? iterator.next() : null;
+ }
+ // finished
+
+ // ((other * matrices) * rowVectors) * colVectorsScalars
+ Hop top = null;
+ if( other == null && matrices != null )
+ top = matrices;
+ else if( other != null && matrices == null )
+ top = other;
+ else if( other != null ) { //matrices != null
+ top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
+ top.setVisited();
+ }
+
+ if( top == null && rowVectors != null )
+ top = rowVectors;
+ else if( rowVectors != null ) { //top != null
+ top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
+ top.setVisited();
+ }
+
+ if( top == null && colVectorsScalars != null )
+ top = colVectorsScalars;
+ else if( colVectorsScalars != null ) { //top != null
+ top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
+ top.setVisited();
}
- return first;
+
+ return top;
}
- private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
- final Hop hop = entry.getKey();
- final int cnt = entry.getValue();
+ private static Hop constructPower(final Hop hop, final int cnt) {
assert(cnt >= 1);
hop.setVisited(); // we will visit the leaves' children next
if (cnt == 1)
return hop;
- Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
pow.setVisited();
return pow;
}
@@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
/**
* A Comparator that orders Hops by their data type, dimension, and sparsity.
* The order is as follows:
- * scalars > row vectors > col vectors >
- * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ * scalars < col vectors < row vectors <
+ * non-vector matrices ordered by sparsity (higher nnz last, unknown sparsity last) >
* other data types.
* Disambiguate by Hop ID.
*/
@@ -174,33 +258,33 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
{
for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
switch(Expression.DataType.values()[i]) {
- case SCALAR: orderDataType[i] = 4; break;
- case MATRIX: orderDataType[i] = 3; break;
+ case SCALAR: orderDataType[i] = 0; break;
+ case MATRIX: orderDataType[i] = 1; break;
case FRAME: orderDataType[i] = 2; break;
- case OBJECT: orderDataType[i] = 1; break;
- case UNKNOWN:orderDataType[i] = 0; break;
+ case OBJECT: orderDataType[i] = 3; break;
+ case UNKNOWN:orderDataType[i] = 4; break;
}
}
@Override
- public final int compare(Hop o1, Hop o2) {
- int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+ public final int compare(final Hop o1, final Hop o2) {
+ final int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
if (c != 0) return c;
// o1 and o2 have the same data type
switch (o1.getDataType()) {
case MATRIX:
// two matrices; check for vectors
- if (o1.getDim1() == 1) { // row vector
- if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
- return compareBySparsityThenId(o1, o2); // both row vectors
- } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
- return -1; // row vectors are the greatest matrices
- } else if (o1.getDim2() == 1) { // col vector
- if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
- return compareBySparsityThenId(o1, o2); // both col vectors
+ if (o1.getDim2() == 1) { // col vector
+ if (o2.getDim2() != 1) return -1; // col vectors are greatest of matrices
+ return compareBySparsityThenId(o1, o2); // both col vectors
} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
- return -1; // col vectors greater than non-vectors
+ return 1; // col vectors are the greatest matrices
+ } else if (o1.getDim1() == 1) { // row vector
+ if (o2.getDim1() != 1) return -1; // row vectors greater than non-vectors
+ return compareBySparsityThenId(o1, o2); // both row vectors
+ } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+ return 1; // row vectors greater than non-vectors
} else { // both non-vectors
return compareBySparsityThenId(o1, o2);
}
@@ -208,13 +292,13 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
return Long.compare(o1.getHopID(), o2.getHopID());
}
}
- private int compareBySparsityThenId(Hop o1, Hop o2) {
+ private int compareBySparsityThenId(final Hop o1, final Hop o2) {
// the hop with more nnz is first; unknown nnz (-1) last
- int c = Long.compare(o1.getNnz(), o2.getNnz());
- if (c != 0) return c;
+ final int c = Long.compare(o1.getNnz(), o2.getNnz());
+ if (c != 0) return -c;
return Long.compare(o1.getHopID(), o2.getHopID());
}
- }.reversed();
+ };
/**
* Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
@@ -242,8 +326,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
* @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
* @param leaves Out parameter. The multiset of multiplicands in the emult chain.
*/
- private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults,
- final Map<Hop, Integer> leaves) {
+ private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
emults.add(root);
@@ -268,13 +351,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
map.put(k, map.getOrDefault(k, 0) + 1);
}
- /**
- * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param leaves The multiset of multiplicands in the emult chain.
- * @return If the multiset is worth optimizing.
- */
- private static boolean isOptimizable(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
- return emults.size() >= 2;
- }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
new file mode 100644
index 0000000..ba5c78d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test rewriting `2*X*3*v*5*w*4*z*5*Y*2*v*2*X`, where `v` and `z` are row vectors and `w` is a column vector,
+ * successfully rewrites to `Y*(X^2)*z*(v^2)*w*2400`.
+ */
+public class RewriteElementwiseMultChainOptimizationAllTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOpAll";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/";
+
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+ double[][] z = getRandomMatrix(1, cols, -1, 1, Ysparsity, 5);
+ double[][] v = getRandomMatrix(1, cols, -1, 1, Xsparsity, 4);
+ double[][] w = getRandomMatrix(rows, 1, -1, 1, Ysparsity, 6);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+ writeInputMatrixWithMTD("z", z, true);
+ writeInputMatrixWithMTD("v", v, true);
+ writeInputMatrixWithMTD("w", w, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check for presence of power operator, if we did a rewrite
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally {
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
new file mode 100644
index 0000000..20f76c2
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+# args[1]=""
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+v = as.matrix(readMM(paste(args[1], "v.mtx", sep="")))
+z = as.matrix(readMM(paste(args[1], "z.mtx", sep="")))
+w = as.matrix(readMM(paste(args[1], "w.mtx", sep="")))
+
+R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v)))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/eca9dbbb/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
new file mode 100644
index 0000000..90f9242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+v = read($3);
+z = read($4);
+w = read($5);
+
+R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X
+
+write(R, $6);
\ No newline at end of file