You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:35 UTC
[09/23] systemml git commit: Review comments, part 1
Review comments, part 1
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b94557fd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b94557fd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b94557fd
Branch: refs/heads/master
Commit: b94557fd2c90c591179cdbf05a32242fadc36448
Parents: d88f867
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 00:35:52 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:37 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 21 +-
.../org/apache/sysml/hops/OptimizerUtils.java | 5 -
.../sysml/hops/rewrite/HopRewriteUtils.java | 3 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 4 +-
.../RewriteAlgebraicSimplificationDynamic.java | 16 +-
.../apache/sysml/hops/rewrite/RewriteEMult.java | 284 -------------------
...RewriteElementwiseMultChainOptimization.java | 281 ++++++++++++++++++
.../functions/misc/RewriteEMultChainTest.java | 127 ---------
...ementwiseMultChainOptimizationChainTest.java | 127 +++++++++
.../ternary/ABATernaryAggregateTest.java | 9 +-
.../functions/misc/ZPackageSuite.java | 1 +
.../functions/ternary/ZPackageSuite.java | 3 +-
12 files changed, 436 insertions(+), 445 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 300a20c..a207831 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -497,7 +497,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
if (binput1.getOp() == OpOp2.POW
&& binput1.getInput().get(1) instanceof LiteralOp) {
LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
- ret = lit.getLongValue() == 3;
+ ret = HopRewriteUtils.isLiteralOfValue(lit, 3); // require exactly 3, matching the power-3 assertion in the tak+* lop construction
}
else if (binput1.getOp() == OpOp2.MULT
// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
@@ -640,15 +640,10 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
boolean handled = false;
if (input1.getOp() == OpOp2.POW) {
- switch ((int)((LiteralOp)input12).getLongValue()) {
- case 3:
- in1 = input11.constructLops();
- in2 = in1;
- in3 = in1;
- break;
- default:
- throw new AssertionError("unreachable; only applies to power 3");
- }
+ assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
+ in1 = input11.constructLops();
+ in2 = in1;
+ in3 = in1;
handled = true;
} else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
@@ -662,8 +657,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
case POW: // A*A*B case
Hop b112 = b11.getInput().get(1);
if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
- && b112 instanceof LiteralOp
- && ((LiteralOp)b112).getLongValue() == 2) {
+ && HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b11.getInput().get(0).constructLops();
in2 = in1;
in3 = input12.constructLops();
@@ -682,8 +676,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
break;
case POW: // A*B*B case
Hop b112 = b12.getInput().get(1);
- if ( b112 instanceof LiteralOp
- && ((LiteralOp)b112).getLongValue() == 2) {
+ if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
in1 = b12.getInput().get(0).constructLops();
in2 = in1;
in3 = input11.constructLops();
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 2a76d07..79b7ee6 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -111,11 +111,6 @@ public class OptimizerUtils
public static boolean ALLOW_CONSTANT_FOLDING = true;
public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
- /**
- * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
- * `A*B*A ==> (A^2)*B`.
- */
- public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
public static boolean ALLOW_OPERATOR_FUSION = true;
/**
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 17ac4ec..8f71359 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -251,8 +251,7 @@ public class HopRewriteUtils
* @return replacement
*/
public static Hop replaceHop(final Hop old, final Hop replacement) {
- final ArrayList<Hop> rootParents = old.getParent();
- if (rootParents.isEmpty())
+ if (old.getParent().isEmpty()) // old has no parents, so there are no parent references to rewire
return replacement; // new old!
HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
return replacement;
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index b6aab38..1053850 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
_dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() );
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
- if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
- _dagRuleSet.add( new RewriteEMult() ); //dependency: cse
+ if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES )
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 166af2f..91c5972 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
+import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+
/**
* Rule: Algebraic Simplifications. Simplifies binary expressions
* in terms of two major purposes: (1) rewrite binary operations
@@ -2051,12 +2053,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
- && !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B
- && hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it
- && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
- && !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2)
- && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it
- && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+ && ( !ALLOW_SUM_PRODUCT_REWRITES
+ || !( HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B
+ && HopRewriteUtils.isLiteralOfValue( // let tak+* handle it (same exponent check as AggUnaryOp)
+ hi2.getInput().get(0).getInput().get(1), 2) ))
+ && ( !ALLOW_SUM_PRODUCT_REWRITES
+ || !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2)
+ && HopRewriteUtils.isLiteralOfValue( // let tak+* handle it (same exponent check as AggUnaryOp)
+ hi2.getInput().get(1).getInput().get(1), 2) ))
)
{
baLeft = hi2.getInput().get(0);
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
deleted file mode 100644
index 5cd1471..0000000
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ /dev/null
@@ -1,284 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.rewrite;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.HopsException;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.parser.Expression;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
-/**
- * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
- *
- * Rewrite a chain of element-wise multiply hops that contain identical elements.
- * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
- *
- * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
- * since foreign parents may rely on the individual results.
- */
-public class RewriteEMult extends HopRewriteRule {
- @Override
- public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
- if( roots == null )
- return null;
- for( int i=0; i<roots.size(); i++ ) {
- Hop h = roots.get(i);
- roots.set(i, rule_RewriteEMult(h));
- }
- return roots;
- }
-
- @Override
- public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
- if( root == null )
- return null;
- return rule_RewriteEMult(root);
- }
-
- private static boolean isBinaryMult(final Hop hop) {
- return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
- }
-
- private static Hop rule_RewriteEMult(final Hop root) {
- if (root.isVisited())
- return root;
- root.setVisited();
-
- // 1. Find immediate subtree of EMults.
- if (isBinaryMult(root)) {
- final Hop left = root.getInput().get(0), right = root.getInput().get(1);
- final Set<BinaryOp> emults = new HashSet<>();
- final Multiset<Hop> leaves = HashMultiset.create();
- findEMultsAndLeaves((BinaryOp)root, emults, leaves);
-
- // 2. Ensure it is profitable to do a rewrite.
- if (isOptimizable(emults, leaves)) {
- // 3. Check for foreign parents.
- // A foreign parent is a parent of some EMult that is not in the set.
- // Foreign parents destroy correctness of this rewrite.
- final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
- (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
- if (okay) {
- // 4. Construct replacement EMults for the leaves
- final Hop replacement = constructReplacement(leaves);
- if (LOG.isDebugEnabled())
- LOG.debug(String.format(
- "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
- emults.size(), root.getHopID(), replacement.getHopID()));
-
- // 5. Replace root with replacement
- final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
-
- // 6. Recurse at leaves (no need to repeat the interior emults)
- for (final Hop leaf : leaves.elementSet()) {
- recurseInputs(leaf);
- }
- return newRoot;
- }
- }
- }
- // This rewrite is not applicable to the current root.
- // Try the root's children.
- recurseInputs(root);
- return root;
- }
-
- private static void recurseInputs(final Hop parent) {
- final ArrayList<Hop> inputs = parent.getInput();
- for (int i = 0; i < inputs.size(); i++) {
- final Hop input = inputs.get(i);
- final Hop newInput = rule_RewriteEMult(input);
- inputs.set(i, newInput);
- }
- }
-
- private static Hop constructReplacement(final Multiset<Hop> leaves) {
- // Sort by data type
- final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
- for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
- final Hop h = entry.getElement();
- // unlink parents (the EMults, which we are throwing away)
- h.getParent().clear();
- sorted.put(h, entry.getCount());
- }
- // sorted contains all leaves, sorted by data type, stripped from their parents
-
- // Construct right-deep EMult tree
- final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
- Hop first = constructPower(iterator.next());
-
- for (int i = 1; i < sorted.size(); i++) {
- final Hop second = constructPower(iterator.next());
- first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
- first.setVisited();
- }
- return first;
- }
-
- private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
- final Hop hop = entry.getKey();
- final int cnt = entry.getValue();
- assert(cnt >= 1);
- if (cnt == 1)
- return hop; // don't set this visited... we will visit this next
- Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
- pow.setVisited();
- return pow;
- }
-
- /**
- * A Comparator that orders Hops by their data type, dimention, and sparsity.
- * The order is as follows:
- * scalars > row vectors > col vectors >
- * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
- * other data types.
- * Disambiguate by Hop ID.
- */
- private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
- @Override
- public final int compare(Hop o1, Hop o2) {
- int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
- if (c != 0) return c;
-
- // o1 and o2 have the same data type
- switch (o1.getDataType()) {
- case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
- case MATRIX:
- // two matrices; check for vectors
- if (o1.getDim1() == 1) { // row vector
- if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
- return compareBySparsityThenId(o1, o2); // both row vectors
- } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
- return -1; // row vectors are the greatest matrices
- } else if (o1.getDim2() == 1) { // col vector
- if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
- return compareBySparsityThenId(o1, o2); // both col vectors
- } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
- return 1; // col vectors greater than non-vectors
- } else { // both non-vectors
- return compareBySparsityThenId(o1, o2);
- }
- default:
- return Long.compare(o1.getHopID(), o2.getHopID());
- }
- }
- private int compareBySparsityThenId(Hop o1, Hop o2) {
- // the hop with more nnz is first; unknown nnz (-1) last
- int c = Long.compare(o1.getNnz(), o2.getNnz());
- if (c != 0) return c;
- return Long.compare(o1.getHopID(), o2.getHopID());
- }
- private final int[] orderDataType;
- {
- Expression.DataType[] dtValues = Expression.DataType.values();
- orderDataType = new int[dtValues.length];
- for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
- switch(dtValues[i]) {
- case SCALAR:
- orderDataType[i] = 4;
- break;
- case MATRIX:
- orderDataType[i] = 3;
- break;
- case FRAME:
- orderDataType[i] = 2;
- break;
- case OBJECT:
- orderDataType[i] = 1;
- break;
- case UNKNOWN:
- orderDataType[i] = 0;
- break;
- }
- }
- }
- };
-
- /**
- * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param child An interior emult hop in the emult chain dag.
- * @return Whether this interior emult or any child emult has a foreign parent.
- */
- private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
- final ArrayList<Hop> parents = child.getParent();
- if (parents.size() > 1)
- for (final Hop parent : parents)
- //noinspection SuspiciousMethodCalls
- if (!emults.contains(parent))
- return false;
- // child does not have foreign parents
-
- final ArrayList<Hop> inputs = child.getInput();
- final Hop left = inputs.get(0), right = inputs.get(1);
- return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
- (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
- }
-
- /**
- * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
- * @param root Root of sub-dag
- * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
- * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
- */
- private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
- // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
- emults.add(root);
-
- final ArrayList<Hop> inputs = root.getInput();
- final Hop left = inputs.get(0), right = inputs.get(1);
-
- if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
- else leaves.add(left);
-
- if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
- else leaves.add(right);
- }
-
- /**
- * Only optimize a subtree of emults if there are at least two emults.
- * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
- * @param leaves The multiset of multiplicands in the emult chain.
- * @return If the multiset is worth optimizing.
- */
- private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
- // Old criterion: there should be at least one repeated leaf
-// for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
-// if (hopEntry.getCount() > 1)
-// return true;
-// }
-// return false;
- return emults.size() >= 2;
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
new file mode 100644
index 0000000..bd873ff
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
+ @Override
+ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+ if( roots == null )
+ return null;
+ for( int i=0; i<roots.size(); i++ ) {
+ Hop h = roots.get(i);
+ roots.set(i, rule_RewriteEMult(h));
+ }
+ return roots;
+ }
+
+ @Override
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+ if( root == null )
+ return null;
+ return rule_RewriteEMult(root);
+ }
+
+ private static boolean isBinaryMult(final Hop hop) {
+ return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+ }
+
+ private static Hop rule_RewriteEMult(final Hop root) {
+ if (root.isVisited())
+ return root;
+ root.setVisited();
+
+ // 1. Find immediate subtree of EMults.
+ if (isBinaryMult(root)) {
+ final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+ final Set<BinaryOp> emults = new HashSet<>();
+ final Multiset<Hop> leaves = HashMultiset.create();
+ findEMultsAndLeaves((BinaryOp)root, emults, leaves);
+
+ // 2. Ensure it is profitable to do a rewrite.
+ if (isOptimizable(emults, leaves)) {
+ // 3. Check for foreign parents.
+ // A foreign parent is a parent of some EMult that is not in the set.
+ // Foreign parents destroy correctness of this rewrite.
+ final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+ (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+ if (okay) {
+ // 4. Construct replacement EMults for the leaves
+ final Hop replacement = constructReplacement(emults, leaves);
+ if (LOG.isDebugEnabled())
+ LOG.debug(String.format(
+ "Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+ emults.size(), root.getHopID(), replacement.getHopID()));
+
+ // 5. Replace root with replacement
+ final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+ // 6. Recurse at leaves (no need to repeat the interior emults)
+ for (final Hop leaf : leaves.elementSet()) {
+ recurseInputs(leaf);
+ }
+ return newRoot;
+ }
+ }
+ }
+ // This rewrite is not applicable to the current root.
+ // Try the root's children.
+ recurseInputs(root);
+ return root;
+ }
+
+ private static void recurseInputs(final Hop parent) {
+ final ArrayList<Hop> inputs = parent.getInput();
+ for (int i = 0; i < inputs.size(); i++) {
+ final Hop input = inputs.get(i);
+ final Hop newInput = rule_RewriteEMult(input);
+ inputs.set(i, newInput);
+ }
+ }
+
+ private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ // Sort by data type
+ final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+ for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+ final Hop h = entry.getElement();
+ // unlink only this chain's EMults (we throw them away); keep foreign parents that still consume h
+ h.getParent().removeAll(emults);
+ sorted.put(h, entry.getCount());
+ }
+ // sorted contains all leaves, sorted by data type, no longer linked to the discarded EMults
+
+ // Construct right-deep EMult tree
+ final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+ Hop first = constructPower(iterator.next());
+
+ for (int i = 1; i < sorted.size(); i++) {
+ final Hop second = constructPower(iterator.next());
+ first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+ first.setVisited();
+ }
+ return first;
+ }
+
+ private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+ final Hop hop = entry.getKey();
+ final int cnt = entry.getValue();
+ assert(cnt >= 1);
+ if (cnt == 1)
+ return hop; // don't set this visited... we will visit this next
+ Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+ pow.setVisited();
+ return pow;
+ }
+
+ /**
+ * A Comparator that orders Hops by their data type, dimension, and sparsity.
+ * The order is as follows:
+ * scalars > row vectors > col vectors >
+ * non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+ * other data types.
+ * Disambiguate by Hop ID.
+ */
+ private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+ @Override
+ public final int compare(Hop o1, Hop o2) {
+ int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+ if (c != 0) return c;
+
+ // o1 and o2 have the same data type
+ switch (o1.getDataType()) {
+ case MATRIX:
+ // two matrices; check for vectors
+ if (o1.getDim1() == 1) { // row vector
+ if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+ return compareBySparsityThenId(o1, o2); // both row vectors
+ } else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+ return -1; // row vectors are the greatest matrices
+ } else if (o1.getDim2() == 1) { // col vector
+ if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+ return compareBySparsityThenId(o1, o2); // both col vectors
+ } else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+ return -1; // col vectors greater than non-vectors, so o2 > o1 (mirror of the case above)
+ } else { // both non-vectors
+ return compareBySparsityThenId(o1, o2);
+ }
+ default:
+ return Long.compare(o1.getHopID(), o2.getHopID());
+ }
+ }
+ private int compareBySparsityThenId(Hop o1, Hop o2) {
+ // more nnz compares greater; unknown nnz (-1) compares least; ties broken by Hop ID
+ int c = Long.compare(o1.getNnz(), o2.getNnz());
+ if (c != 0) return c;
+ return Long.compare(o1.getHopID(), o2.getHopID());
+ }
+ private final int[] orderDataType;
+ {
+ Expression.DataType[] dtValues = Expression.DataType.values();
+ orderDataType = new int[dtValues.length];
+ for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+ switch(dtValues[i]) {
+ case SCALAR:
+ orderDataType[i] = 4;
+ break;
+ case MATRIX:
+ orderDataType[i] = 3;
+ break;
+ case FRAME:
+ orderDataType[i] = 2;
+ break;
+ case OBJECT:
+ orderDataType[i] = 1;
+ break;
+ case UNKNOWN:
+ orderDataType[i] = 0;
+ break;
+ }
+ }
+ }
+ };
+
+ /**
+ * Check that neither this interior emult nor any emult below it has a parent outside the set of emults.
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param child An interior emult hop in the emult chain dag.
+ * @return true if neither this interior emult nor any descendant emult has a foreign parent; false otherwise.
+ */
+ private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+ final ArrayList<Hop> parents = child.getParent();
+ if (parents.size() > 1)
+ for (final Hop parent : parents) // parents are Hops while emults holds BinaryOps; contains(Object) is still correct
+ //noinspection SuspiciousMethodCalls
+ if (!emults.contains(parent))
+ return false;
+ // child does not have foreign parents
+
+ final ArrayList<Hop> inputs = child.getInput();
+ final Hop left = inputs.get(0), right = inputs.get(1);
+ return (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+ (!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+ }
+
+ /**
+ * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
+ * @param root Root of sub-dag
+ * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+ * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
+ */
+ private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ // Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+ emults.add(root);
+
+ final ArrayList<Hop> inputs = root.getInput();
+ final Hop left = inputs.get(0), right = inputs.get(1);
+
+ if (isBinaryMult(left))
+ findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+ else
+ leaves.add(left);
+
+ if (isBinaryMult(right))
+ findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+ else
+ leaves.add(right);
+ }
+
+ /**
+ * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+ * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+ * @param leaves The multiset of multiplicands in the emult chain.
+ * @return If the multiset is worth optimizing.
+ */
+ private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+ return emults.size() >= 2;
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
deleted file mode 100644
index 85dbea4..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
- */
-public class RewriteEMultChainTest extends AutomatedTestBase
-{
- private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
- private static final String TEST_DIR = "functions/misc/";
- private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
-
- private static final int rows = 123;
- private static final int cols = 321;
- private static final double eps = Math.pow(10, -10);
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
- }
-
- private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
- {
- RUNTIME_PLATFORM platformOld = rtplatform;
- switch( et ){
- case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
- case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
- default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == RUNTIME_PLATFORM.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
- boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
-
- try
- {
- TestConfiguration config = getTestConfiguration(testname);
- loadTestConfiguration(config);
-
- String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
- fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
-
- double Xsparsity = 0.8, Ysparsity = 0.6;
- double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
- writeInputMatrixWithMTD("X", X, true);
- writeInputMatrixWithMTD("Y", Y, true);
-
- //execute tests
- runTest(true, false, null, -1);
- runRScript(true);
-
- //compare matrices
- HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
- HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
- TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-
- //check for presence of power operator, if we did a rewrite
- if( rewrites ) {
- Assert.assertTrue(heavyHittersContainsSubString("^2"));
- }
- }
- finally {
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
new file mode 100644
index 0000000..47b2f0e
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether an element-wise multiply chain of the form `A*B*A` (here over inputs X and Y) successfully rewrites to `(A^2)*B`. Results are checked against an R reference script; with rewrites enabled, a `^2` heavy hitter must also appear in the runtime statistics.
+ */
+public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
+
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10); // tolerance for the DML vs. R comparison
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ // Runs the DML script on the backend for et with ALLOW_SUM_PRODUCT_REWRITES set to rewrites, then compares with R. NOTE(review): "MatrixMultChain" names look copied from the matrix-mult-chain test; this exercises the element-wise chain rewrite.
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){ // map the requested exec type to a runtime platform
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; // restored in finally
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; // restored in finally
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+
+ //execute tests (DML script first, then the R reference script)
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check for presence of power operator, if we did a rewrite
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally {
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 1829bf0..460829d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -36,7 +37,7 @@ import org.junit.Test;
/**
* Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`.
- * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ * Checks compatibility with {@link RewriteElementwiseMultChainOptimization}.
*/
public class ABATernaryAggregateTest extends AutomatedTestBase
{
@@ -368,14 +369,14 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
- rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+ rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
@@ -411,7 +412,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
}
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e352e6d..deea784 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,6 +50,7 @@ import org.junit.runners.Suite;
ReadAfterWriteTest.class,
RewriteCSETransposeScalarTest.class,
RewriteCTableToRExpandTest.class,
+ RewriteElementwiseMultChainOptimizationChainTest.class,
RewriteEliminateAggregatesTest.class,
RewriteFuseBinaryOpChainTest.class,
RewriteFusedRandTest.class,
http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
index 784177d..ee14359 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
@@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
-/** Group together the tests in this package into a single suite so that the Maven build
+/* Group together the tests in this package into a single suite so that the Maven build
* won't run two of them at once. */
@RunWith(Suite.class)
@Suite.SuiteClasses({
+ ABATernaryAggregateTest.class,
CentralMomentWeightsTest.class,
CovarianceWeightsTest.class,
CTableMatrixIgnoreZerosTest.class,